From 05ba8279073d229cb825ba00740cbbabda7b1203 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 9 Sep 2022 11:48:34 -0400
Subject: [PATCH 1/5] ARROW-17661: [C++][Python] Add Flight SQL ADBC driver
Co-authored-by: Sutou Kouhei
---
ci/scripts/python_build.sh | 1 +
ci/scripts/python_test.sh | 4 +
cpp/cmake_modules/ThirdpartyToolchain.cmake | 120 +
cpp/src/arrow/c/adbc_internal.h | 1207 ++++++++++
cpp/src/arrow/flight/sql/CMakeLists.txt | 13 +-
cpp/src/arrow/flight/sql/adbc_driver.cc | 2062 +++++++++++++++++
.../arrow/flight/sql/adbc_driver_internal.cc | 191 ++
.../arrow/flight/sql/adbc_driver_internal.h | 198 ++
cpp/src/arrow/flight/sql/adbc_driver_test.cc | 1057 +++++++++
cpp/src/arrow/flight/sql/client.cc | 1 +
cpp/src/arrow/flight/sql/client.h | 14 +
cpp/src/arrow/symbols.map | 3 +
cpp/thirdparty/versions.txt | 6 +
python/CMakeLists.txt | 29 +
python/pyarrow/_flight_sql.pyx | 29 +
python/pyarrow/conftest.py | 8 +
python/pyarrow/flight_sql.py | 153 ++
.../pyarrow/includes/libarrow_flight_sql.pxd | 26 +
python/pyarrow/tests/test_flight_sql.py | 67 +
python/setup.py | 6 +
20 files changed, 5193 insertions(+), 2 deletions(-)
create mode 100644 cpp/src/arrow/c/adbc_internal.h
create mode 100644 cpp/src/arrow/flight/sql/adbc_driver.cc
create mode 100644 cpp/src/arrow/flight/sql/adbc_driver_internal.cc
create mode 100644 cpp/src/arrow/flight/sql/adbc_driver_internal.h
create mode 100644 cpp/src/arrow/flight/sql/adbc_driver_test.cc
create mode 100644 python/pyarrow/_flight_sql.pyx
create mode 100644 python/pyarrow/flight_sql.py
create mode 100644 python/pyarrow/includes/libarrow_flight_sql.pxd
create mode 100644 python/pyarrow/tests/test_flight_sql.py
diff --git a/ci/scripts/python_build.sh b/ci/scripts/python_build.sh
index cfac68bd6ec..1275e6405f3 100755
--- a/ci/scripts/python_build.sh
+++ b/ci/scripts/python_build.sh
@@ -57,6 +57,7 @@ export PYARROW_BUILD_TYPE=${CMAKE_BUILD_TYPE:-debug}
export PYARROW_WITH_CUDA=${ARROW_CUDA:-OFF}
export PYARROW_WITH_DATASET=${ARROW_DATASET:-ON}
export PYARROW_WITH_FLIGHT=${ARROW_FLIGHT:-OFF}
+export PYARROW_WITH_FLIGHT_SQL=${ARROW_FLIGHT_SQL:-OFF}
export PYARROW_WITH_GANDIVA=${ARROW_GANDIVA:-OFF}
export PYARROW_WITH_GCS=${ARROW_GCS:-OFF}
export PYARROW_WITH_HDFS=${ARROW_HDFS:-ON}
diff --git a/ci/scripts/python_test.sh b/ci/scripts/python_test.sh
index 2d5bd5dd9ff..d2e1cb206d1 100755
--- a/ci/scripts/python_test.sh
+++ b/ci/scripts/python_test.sh
@@ -39,6 +39,9 @@ export ARROW_DEBUG_MEMORY_POOL=trap
: ${PYARROW_TEST_CUDA:=${ARROW_CUDA:-ON}}
: ${PYARROW_TEST_DATASET:=${ARROW_DATASET:-ON}}
: ${PYARROW_TEST_FLIGHT:=${ARROW_FLIGHT:-ON}}
+# Flight SQL must be enabled explicitly due to optional dependency on
+# ADBC driver manager
+: ${PYARROW_TEST_FLIGHT_SQL:=OFF}
: ${PYARROW_TEST_GANDIVA:=${ARROW_GANDIVA:-ON}}
: ${PYARROW_TEST_GCS:=${ARROW_GCS:-ON}}
: ${PYARROW_TEST_HDFS:=${ARROW_HDFS:-ON}}
@@ -49,6 +52,7 @@ export ARROW_DEBUG_MEMORY_POOL=trap
export PYARROW_TEST_CUDA
export PYARROW_TEST_DATASET
export PYARROW_TEST_FLIGHT
+export PYARROW_TEST_FLIGHT_SQL
export PYARROW_TEST_GANDIVA
export PYARROW_TEST_GCS
export PYARROW_TEST_HDFS
diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake
index 3eda538fb2e..365aa936585 100644
--- a/cpp/cmake_modules/ThirdpartyToolchain.cmake
+++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake
@@ -45,6 +45,7 @@ set(ARROW_RE2_LINKAGE
set(ARROW_THIRDPARTY_DEPENDENCIES
absl
+ AdbcValidation
AWSSDK
benchmark
Boost
@@ -59,6 +60,7 @@ set(ARROW_THIRDPARTY_DEPENDENCIES
jemalloc
LLVM
lz4
+ nanoarrow
nlohmann_json
opentelemetry-cpp
ORC
@@ -149,6 +151,8 @@ endforeach()
macro(build_dependency DEPENDENCY_NAME)
if("${DEPENDENCY_NAME}" STREQUAL "absl")
build_absl()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "AdbcValidation")
+ build_adbc_validation()
elseif("${DEPENDENCY_NAME}" STREQUAL "AWSSDK")
build_awssdk()
elseif("${DEPENDENCY_NAME}" STREQUAL "benchmark")
@@ -175,6 +179,8 @@ macro(build_dependency DEPENDENCY_NAME)
build_jemalloc()
elseif("${DEPENDENCY_NAME}" STREQUAL "lz4")
build_lz4()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "nanoarrow")
+ build_nanoarrow()
elseif("${DEPENDENCY_NAME}" STREQUAL "nlohmann_json")
build_nlohmann_json()
elseif("${DEPENDENCY_NAME}" STREQUAL "opentelemetry-cpp")
@@ -345,6 +351,11 @@ if(ARROW_FLIGHT)
set(ARROW_WITH_GRPC ON)
endif()
+if(ARROW_FLIGHT_SQL AND ARROW_BUILD_TESTS)
+ set(ARROW_WITH_ADBC_VALIDATION ON)
+ set(ARROW_WITH_NANOARROW ON)
+endif()
+
if(ARROW_WITH_GRPC)
set(ARROW_WITH_RE2 ON)
set(ARROW_WITH_ZLIB ON)
@@ -431,6 +442,14 @@ else()
)
endif()
+if(DEFINED ENV{ARROW_ADBC_URL})
+ set(ADBC_SOURCE_URL "$ENV{ARROW_ADBC_URL}")
+else()
+ set_urls(ADBC_SOURCE_URL
+ "https://github.com/apache/arrow-adbc/archive/${ARROW_ADBC_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
if(DEFINED ENV{ARROW_AWS_C_COMMON_URL})
set(AWS_C_COMMON_SOURCE_URL "$ENV{ARROW_AWS_C_COMMON_URL}")
else()
@@ -577,6 +596,14 @@ else()
"${THIRDPARTY_MIRROR_URL}/mimalloc-${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz")
endif()
+if(DEFINED ENV{ARROW_NANOARROW_URL})
+ set(NANOARROW_SOURCE_URL "$ENV{ARROW_NANOARROW_URL}")
+else()
+ set_urls(NANOARROW_SOURCE_URL
+ "https://github.com/apache/arrow-nanoarrow/archive/${ARROW_NANOARROW_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
if(DEFINED ENV{ARROW_NLOHMANN_JSON_URL})
set(NLOHMANN_JSON_SOURCE_URL "$ENV{ARROW_NLOHMANN_JSON_URL}")
else()
@@ -4584,6 +4611,99 @@ if(ARROW_WITH_OPENTELEMETRY)
message(STATUS "Found OpenTelemetry headers: ${OPENTELEMETRY_INCLUDE_DIR}")
endif()
+# ----------------------------------------------------------------------
+# nanoarrow (only used for tests)
+# Builds nanoarrow via ExternalProject and defines the imported static target "nanoarrow".
+macro(build_nanoarrow)
+ message(STATUS "Building nanoarrow from source")
+ set(NANOARROW_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/nanoarrow_ep-install")
+ set(NANOARROW_INCLUDE_DIR "${NANOARROW_PREFIX}/include")
+ set(NANOARROW_LIB_DIR "lib")
+ # CMake arguments for the external build; it installs under NANOARROW_PREFIX.
+ set(NANOARROW_COMMON_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_LIBDIR=${NANOARROW_LIB_DIR}"
+ "-DCMAKE_INSTALL_PREFIX=${NANOARROW_PREFIX}"
+ "-DCMAKE_PREFIX_PATH=${NANOARROW_PREFIX}"
+ "-DCMAKE_UNITY_BUILD=ON")
+ set(NANOARROW_STATIC_LIBRARY
+ "${NANOARROW_PREFIX}/${NANOARROW_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}nanoarrow${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ # Create the include directory up front; the imported target below refers to it.
+ file(MAKE_DIRECTORY ${NANOARROW_INCLUDE_DIR})
+ # Imported target for consumers; its library file is a byproduct of nanoarrow_ep.
+ add_library(nanoarrow STATIC IMPORTED)
+ set_target_properties(nanoarrow
+ PROPERTIES IMPORTED_LOCATION "${NANOARROW_STATIC_LIBRARY}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${NANOARROW_INCLUDE_DIR}")
+ externalproject_add(nanoarrow_ep
+ ${EP_LOG_OPTIONS}
+ URL ${NANOARROW_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_NANOARROW_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${NANOARROW_COMMON_CMAKE_ARGS}
+ BUILD_BYPRODUCTS "${NANOARROW_STATIC_LIBRARY}")
+ set(NANOARROW_LINK_LIBRARIES nanoarrow)
+ add_dependencies(nanoarrow nanoarrow_ep)
+endmacro()
+
+if(ARROW_WITH_NANOARROW)
+ set(nanoarrow_SOURCE "AUTO")
+ resolve_dependency(nanoarrow HAVE_ALT FALSE)
+ # Log the include dir and link libraries (set by build_nanoarrow on a source build).
+ message(STATUS "Found nanoarrow headers: ${NANOARROW_INCLUDE_DIR}")
+ message(STATUS "Found nanoarrow libraries: ${NANOARROW_LINK_LIBRARIES}")
+endif()
+
+# ----------------------------------------------------------------------
+# ADBC validation suite (only used for tests)
+# Builds c/validation from the ADBC sources; defines AdbcValidation::adbc_validation.
+macro(build_adbc_validation)
+ message(STATUS "Building ADBC validation suite from source")
+ set(ADBC_VALIDATION_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/adbcvalidation_ep-install")
+ set(ADBC_VALIDATION_INCLUDE_DIR "${ADBC_VALIDATION_PREFIX}/include")
+ set(ADBC_VALIDATION_LIB_DIR "lib")
+ # CMake arguments for the external build; it installs under ADBC_VALIDATION_PREFIX.
+ set(ADBC_VALIDATION_COMMON_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_LIBDIR=${ADBC_VALIDATION_LIB_DIR}"
+ "-DCMAKE_INSTALL_PREFIX=${ADBC_VALIDATION_PREFIX}"
+ "-DCMAKE_PREFIX_PATH=${ADBC_VALIDATION_PREFIX}"
+ "-DCMAKE_UNITY_BUILD=ON")
+ set(ADBC_VALIDATION_STATIC_LIBRARY
+ "${ADBC_VALIDATION_PREFIX}/${ADBC_VALIDATION_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}adbc_validation${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ # Create the include directory up front; the imported target below refers to it.
+ file(MAKE_DIRECTORY ${ADBC_VALIDATION_INCLUDE_DIR})
+ # Imported target for tests; it links nanoarrow, GTest::gtest and GTest::gmock.
+ add_library(AdbcValidation::adbc_validation STATIC IMPORTED)
+ set_target_properties(AdbcValidation::adbc_validation
+ PROPERTIES IMPORTED_LOCATION "${ADBC_VALIDATION_STATIC_LIBRARY}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${ADBC_VALIDATION_INCLUDE_DIR}"
+ INTERFACE_LINK_LIBRARIES nanoarrow
+ GTest::gtest GTest::gmock)
+ externalproject_add(adbcvalidation_ep
+ ${EP_LOG_OPTIONS}
+ URL ${ADBC_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_ADBC_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${ADBC_VALIDATION_COMMON_CMAKE_ARGS}
+ BUILD_BYPRODUCTS "${ADBC_VALIDATION_STATIC_LIBRARY}"
+ SOURCE_SUBDIR c/validation)
+ add_dependencies(AdbcValidation::adbc_validation nanoarrow_ep)
+ set(ADBCVALIDATION_LINK_LIBRARIES AdbcValidation::adbc_validation)
+ add_dependencies(AdbcValidation::adbc_validation adbcvalidation_ep)
+endmacro()
+
+if(ARROW_WITH_ADBC_VALIDATION)
+ set(AdbcValidation_SOURCE "AUTO")
+ resolve_dependency(AdbcValidation HAVE_ALT FALSE)
+ # build_adbc_validation() stores the library list in ADBCVALIDATION_LINK_LIBRARIES.
+ message(STATUS "Found ADBC validation suite headers: ${ADBC_VALIDATION_INCLUDE_DIR}")
+ message(STATUS "Found ADBC validation suite libraries: ${ADBCVALIDATION_LINK_LIBRARIES}"
+ )
+endif()
+
# ----------------------------------------------------------------------
# AWS SDK for C++
diff --git a/cpp/src/arrow/c/adbc_internal.h b/cpp/src/arrow/c/adbc_internal.h
new file mode 100644
index 00000000000..a1ff53441db
--- /dev/null
+++ b/cpp/src/arrow/c/adbc_internal.h
@@ -0,0 +1,1207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// \file adbc_internal.h ADBC: Arrow Database Connectivity
+///
+/// An Arrow-based interface between applications and database
+/// drivers. ADBC aims to provide a vendor-independent API for SQL
+/// and Substrait-based database access that is targeted at
+/// analytics/OLAP use cases.
+///
+/// This API is intended to be implemented directly by drivers and
+/// used directly by client applications. To assist portability
+/// between different vendors, a "driver manager" library is also
+/// provided, which implements this same API, but dynamically loads
+/// drivers internally and forwards calls appropriately.
+///
+/// ADBC uses structs with free functions that operate on those
+/// structs to model objects.
+///
+/// In general, objects allow serialized access from multiple threads,
+/// but not concurrent access. Specific implementations may permit
+/// multiple threads.
+///
+/// \version 1.0.0
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+/// \defgroup Arrow C Data Interface
+/// Definitions for the C Data Interface/C Stream Interface.
+///
+/// See https://arrow.apache.org/docs/format/CDataInterface.html
+///
+/// @{
+
+//! @cond Doxygen_Suppress
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Extra guard for versions of Arrow without the canonical guard
+#ifndef ARROW_FLAG_DICTIONARY_ORDERED
+
+#ifndef ARROW_C_DATA_INTERFACE
+#define ARROW_C_DATA_INTERFACE
+
+#define ARROW_FLAG_DICTIONARY_ORDERED 1
+#define ARROW_FLAG_NULLABLE 2
+#define ARROW_FLAG_MAP_KEYS_SORTED 4
+
+struct ArrowSchema {
+ // Array type description
+ const char* format;
+ const char* name;
+ const char* metadata;
+ int64_t flags;
+ int64_t n_children;
+ struct ArrowSchema** children;
+ struct ArrowSchema* dictionary;
+
+ // Release callback
+ void (*release)(struct ArrowSchema*);
+ // Opaque producer-specific data
+ void* private_data;
+};
+
+struct ArrowArray {
+ // Array data description
+ int64_t length;
+ int64_t null_count;
+ int64_t offset;
+ int64_t n_buffers;
+ int64_t n_children;
+ const void** buffers;
+ struct ArrowArray** children;
+ struct ArrowArray* dictionary;
+
+ // Release callback
+ void (*release)(struct ArrowArray*);
+ // Opaque producer-specific data
+ void* private_data;
+};
+
+#endif // ARROW_C_DATA_INTERFACE
+
+#ifndef ARROW_C_STREAM_INTERFACE
+#define ARROW_C_STREAM_INTERFACE
+
+struct ArrowArrayStream {
+ // Callback to get the stream type
+ // (will be the same for all arrays in the stream).
+ //
+ // Return value: 0 if successful, an `errno`-compatible error code otherwise.
+ //
+ // If successful, the ArrowSchema must be released independently from the stream.
+ int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);
+
+ // Callback to get the next array
+ // (if no error and the array is released, the stream has ended)
+ //
+ // Return value: 0 if successful, an `errno`-compatible error code otherwise.
+ //
+ // If successful, the ArrowArray must be released independently from the stream.
+ int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);
+
+ // Callback to get optional detailed error information.
+ // This must only be called if the last stream operation failed
+ // with a non-0 return code.
+ //
+ // Return value: pointer to a null-terminated character array describing
+ // the last error, or NULL if no description is available.
+ //
+ // The returned pointer is only valid until the next operation on this stream
+ // (including release).
+ const char* (*get_last_error)(struct ArrowArrayStream*);
+
+ // Release callback: release the stream's own resources.
+ // Note that arrays returned by `get_next` must be individually released.
+ void (*release)(struct ArrowArrayStream*);
+
+ // Opaque producer-specific data
+ void* private_data;
+};
+
+#endif // ARROW_C_STREAM_INTERFACE
+#endif // ARROW_FLAG_DICTIONARY_ORDERED
+
+//! @endcond
+
+/// @}
+
+#ifndef ADBC
+#define ADBC
+
+// Storage class macros for Windows
+// Allow overriding/aliasing with application-defined macros
+#if !defined(ADBC_EXPORT)
+#if defined(_WIN32)
+#if defined(ADBC_EXPORTING)
+#define ADBC_EXPORT __declspec(dllexport)
+#else
+#define ADBC_EXPORT __declspec(dllimport)
+#endif // defined(ADBC_EXPORTING)
+#else
+#define ADBC_EXPORT
+#endif // defined(_WIN32)
+#endif // !defined(ADBC_EXPORT)
+
+/// \defgroup adbc-error-handling Error Handling
+/// ADBC uses integer error codes to signal errors. To provide more
+/// detail about errors, functions may also return an AdbcError via an
+/// optional out parameter, which can be inspected. If provided, it is
+/// the responsibility of the caller to zero-initialize the AdbcError
+/// value.
+///
+/// @{
+
+/// \brief Error codes for operations that may fail.
+typedef uint8_t AdbcStatusCode;
+
+/// \brief No error.
+#define ADBC_STATUS_OK 0
+/// \brief An unknown error occurred.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_UNKNOWN 1
+/// \brief The operation is not implemented or supported.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_NOT_IMPLEMENTED 2
+/// \brief A requested resource was not found.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_NOT_FOUND 3
+/// \brief A requested resource already exists.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_ALREADY_EXISTS 4
+/// \brief The arguments are invalid, likely a programming error.
+///
+/// For instance, they may be of the wrong format, or out of range.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_INVALID_ARGUMENT 5
+/// \brief The preconditions for the operation are not met, likely a
+/// programming error.
+///
+/// For instance, the object may be uninitialized, or may have not
+/// been fully configured.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_INVALID_STATE 6
+/// \brief Invalid data was processed (not a programming error).
+///
+/// For instance, a division by zero may have occurred during query
+/// execution.
+///
+/// May indicate a database-side error only.
+#define ADBC_STATUS_INVALID_DATA 7
+/// \brief The database's integrity was affected.
+///
+/// For instance, a foreign key check may have failed, or a uniqueness
+/// constraint may have been violated.
+///
+/// May indicate a database-side error only.
+#define ADBC_STATUS_INTEGRITY 8
+/// \brief An error internal to the driver or database occurred.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_INTERNAL 9
+/// \brief An I/O error occurred.
+///
+/// For instance, a remote service may be unavailable.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_IO 10
+/// \brief The operation was cancelled, not due to a timeout.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_CANCELLED 11
+/// \brief The operation was cancelled due to a timeout.
+///
+/// May indicate a driver-side or database-side error.
+#define ADBC_STATUS_TIMEOUT 12
+/// \brief Authentication failed.
+///
+/// May indicate a database-side error only.
+#define ADBC_STATUS_UNAUTHENTICATED 13
+/// \brief The client is not authorized to perform the given operation.
+///
+/// May indicate a database-side error only.
+#define ADBC_STATUS_UNAUTHORIZED 14
+
+/// \brief A detailed error message for an operation.
+struct ADBC_EXPORT AdbcError {
+ /// \brief The error message.
+ char* message;
+
+ /// \brief A vendor-specific error code, if applicable.
+ int32_t vendor_code;
+
+ /// \brief A SQLSTATE error code, if provided, as defined by the
+ /// SQL:2003 standard. If not set, it should be set to
+ /// "\0\0\0\0\0".
+ char sqlstate[5];
+
+ /// \brief Release the contained error.
+ ///
+ /// Unlike other structures, this is an embedded callback to make it
+ /// easier for the driver manager and driver to cooperate.
+ void (*release)(struct AdbcError* error);
+};
+
+/// @}
+
+/// \defgroup adbc-constants Constants
+/// @{
+
+/// \brief ADBC revision 1.0.0.
+///
+/// When passed to an AdbcDriverInitFunc(), the driver parameter must
+/// point to an AdbcDriver.
+#define ADBC_VERSION_1_0_0 1000000
+
+/// \brief Canonical option value for enabling an option.
+///
+/// For use as the value in SetOption calls.
+#define ADBC_OPTION_VALUE_ENABLED "true"
+/// \brief Canonical option value for disabling an option.
+///
+/// For use as the value in SetOption calls.
+#define ADBC_OPTION_VALUE_DISABLED "false"
+
+/// \brief The database vendor/product name (e.g. the server name).
+/// (type: utf8).
+///
+/// \see AdbcConnectionGetInfo
+#define ADBC_INFO_VENDOR_NAME 0
+/// \brief The database vendor/product version (type: utf8).
+///
+/// \see AdbcConnectionGetInfo
+#define ADBC_INFO_VENDOR_VERSION 1
+/// \brief The database vendor/product Arrow library version (type:
+/// utf8).
+///
+/// \see AdbcConnectionGetInfo
+#define ADBC_INFO_VENDOR_ARROW_VERSION 2
+
+/// \brief The driver name (type: utf8).
+///
+/// \see AdbcConnectionGetInfo
+#define ADBC_INFO_DRIVER_NAME 100
+/// \brief The driver version (type: utf8).
+///
+/// \see AdbcConnectionGetInfo
+#define ADBC_INFO_DRIVER_VERSION 101
+/// \brief The driver Arrow library version (type: utf8).
+///
+/// \see AdbcConnectionGetInfo
+#define ADBC_INFO_DRIVER_ARROW_VERSION 102
+
+/// \brief Return metadata on catalogs, schemas, tables, and columns.
+///
+/// \see AdbcConnectionGetObjects
+#define ADBC_OBJECT_DEPTH_ALL 0
+/// \brief Return metadata on catalogs only.
+///
+/// \see AdbcConnectionGetObjects
+#define ADBC_OBJECT_DEPTH_CATALOGS 1
+/// \brief Return metadata on catalogs and schemas.
+///
+/// \see AdbcConnectionGetObjects
+#define ADBC_OBJECT_DEPTH_DB_SCHEMAS 2
+/// \brief Return metadata on catalogs, schemas, and tables.
+///
+/// \see AdbcConnectionGetObjects
+#define ADBC_OBJECT_DEPTH_TABLES 3
+/// \brief Return metadata on catalogs, schemas, tables, and columns.
+///
+/// \see AdbcConnectionGetObjects
+#define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL
+
+/// \brief The name of the canonical option for whether autocommit is
+/// enabled.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit"
+
+/// \brief The name of the canonical option for whether the current
+/// connection should be restricted to being read-only.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly"
+
+/// \brief The name of the canonical option for setting the isolation
+/// level of a transaction.
+///
+/// Should only be used in conjunction with autocommit disabled and
+/// AdbcConnectionCommit / AdbcConnectionRollback. If the desired
+/// isolation level is not supported by a driver, it should return an
+/// appropriate error.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL \
+ "adbc.connection.transaction.isolation_level"
+
+/// \brief Use database or driver default isolation level
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_OPTION_ISOLATION_LEVEL_DEFAULT \
+ "adbc.connection.transaction.isolation.default"
+
+/// \brief The lowest isolation level. Dirty reads are allowed, so one
+/// transaction may see not-yet-committed changes made by others.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_OPTION_ISOLATION_LEVEL_READ_UNCOMMITTED \
+ "adbc.connection.transaction.isolation.read_uncommitted"
+
+/// \brief Lock-based concurrency control keeps write locks until the
+/// end of the transaction, but read locks are released as soon as a
+/// SELECT is performed. Non-repeatable reads can occur in this
+/// isolation level.
+///
+/// More simply put, Read Committed is an isolation level that guarantees
+/// that any data read is committed at the moment it is read. It simply
+/// restricts the reader from seeing any intermediate, uncommitted,
+/// 'dirty' reads. It makes no promise whatsoever that if the transaction
+/// re-issues the read, it will find the same data; data is free to change
+/// after it is read.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_OPTION_ISOLATION_LEVEL_READ_COMMITTED \
+ "adbc.connection.transaction.isolation.read_committed"
+
+/// \brief Lock-based concurrency control keeps read AND write locks
+/// (acquired on selection data) until the end of the transaction.
+///
+/// However, range-locks are not managed, so phantom reads can occur.
+/// Write skew is possible at this isolation level in some systems.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_OPTION_ISOLATION_LEVEL_REPEATABLE_READ \
+ "adbc.connection.transaction.isolation.repeatable_read"
+
+/// \brief This isolation guarantees that all reads in the transaction
+/// will see a consistent snapshot of the database and the transaction
+/// should only successfully commit if no updates conflict with any
+/// concurrent updates made since that snapshot.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_OPTION_ISOLATION_LEVEL_SNAPSHOT \
+ "adbc.connection.transaction.isolation.snapshot"
+
+/// \brief Serializability requires read and write locks to be released
+/// only at the end of the transaction. This includes acquiring range-
+/// locks when a select query uses a ranged WHERE clause to avoid
+/// phantom reads.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_OPTION_ISOLATION_LEVEL_SERIALIZABLE \
+ "adbc.connection.transaction.isolation.serializable"
+
+/// \brief The central distinction between serializability and linearizability
+/// is that serializability is a global property; a property of an entire
+/// history of operations and transactions. Linearizability is a local
+/// property; a property of a single operation/transaction.
+///
+/// Linearizability can be viewed as a special case of strict serializability
+/// where transactions are restricted to consist of a single operation applied
+/// to a single object.
+///
+/// \see AdbcConnectionSetOption
+#define ADBC_OPTION_ISOLATION_LEVEL_LINEARIZABLE \
+ "adbc.connection.transaction.isolation.linearizable"
+
+/// \defgroup adbc-statement-ingestion Bulk Data Ingestion
+/// While it is possible to insert data via prepared statements, it can
+/// be more efficient to explicitly perform a bulk insert. For
+/// compatible drivers, this can be accomplished by setting up and
+/// executing a statement. Instead of setting a SQL query or Substrait
+/// plan, bind the source data via AdbcStatementBind, and set the name
+/// of the table to be created via AdbcStatementSetOption and the
+/// options below. Then, call AdbcStatementExecute with
+/// ADBC_OUTPUT_TYPE_UPDATE.
+///
+/// @{
+
+/// \brief The name of the target table for a bulk insert.
+///
+/// The driver should attempt to create the table if it does not
+/// exist. If the table exists but has a different schema,
+/// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be
+/// appended to the target table.
+#define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table"
+/// \brief Whether to create (the default) or append.
+#define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode"
+/// \brief Create the table and insert data; error if the table exists.
+#define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create"
+/// \brief Do not create the table, and insert data; error if the
+/// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match
+/// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS).
+#define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append"
+
+/// @}
+
+/// @}
+
+/// \defgroup adbc-database Database Initialization
+/// Clients first initialize a database, then create a connection
+/// (below). This gives the implementation a place to initialize and
+/// own any common connection state. For example, in-memory databases
+/// can place ownership of the actual database in this object.
+/// @{
+
+/// \brief An instance of a database.
+///
+/// Must be kept alive as long as any connections exist.
+struct ADBC_EXPORT AdbcDatabase {
+ /// \brief Opaque implementation-defined state.
+ /// This field is NULLPTR iff the database is uninitialized/freed.
+ void* private_data;
+ /// \brief The associated driver (used by the driver manager to help
+ /// track state).
+ struct AdbcDriver* private_driver;
+};
+
+/// @}
+
+/// \defgroup adbc-connection Connection Establishment
+/// Functions for creating, using, and releasing database connections.
+/// @{
+
+/// \brief An active database connection.
+///
+/// Provides methods for query execution, managing prepared
+/// statements, using transactions, and so on.
+///
+/// Connections are not required to be thread-safe, but they can be
+/// used from multiple threads so long as clients take care to
+/// serialize accesses to a connection.
+struct ADBC_EXPORT AdbcConnection {
+ /// \brief Opaque implementation-defined state.
+ /// This field is NULLPTR iff the connection is uninitialized/freed.
+ void* private_data;
+ /// \brief The associated driver (used by the driver manager to help
+ /// track state).
+ struct AdbcDriver* private_driver;
+};
+
+/// @}
+
+/// \defgroup adbc-statement Managing Statements
+/// Applications should first initialize a statement with
+/// AdbcStatementNew. Then, the statement should be configured with
+/// functions like AdbcStatementSetSqlQuery and
+/// AdbcStatementSetOption. Finally, the statement can be executed
+/// with AdbcStatementExecuteQuery (or call AdbcStatementPrepare first
+/// to turn it into a prepared statement instead).
+/// @{
+
+/// \brief A container for all state needed to execute a database
+/// query, such as the query itself, parameters for prepared
+/// statements, driver parameters, etc.
+///
+/// Statements may represent queries or prepared statements.
+///
+/// Statements may be used multiple times and can be reconfigured
+/// (e.g. they can be reused to execute multiple different queries).
+/// However, executing a statement (and changing certain other state)
+/// will invalidate result sets obtained prior to that execution.
+///
+/// Multiple statements may be created from a single connection.
+/// However, the driver may block or error if they are used
+/// concurrently (whether from a single thread or multiple threads).
+///
+/// Statements are not required to be thread-safe, but they can be
+/// used from multiple threads so long as clients take care to
+/// serialize accesses to a statement.
+struct ADBC_EXPORT AdbcStatement {
+ /// \brief Opaque implementation-defined state.
+ /// This field is NULLPTR iff the statement is uninitialized/freed.
+ void* private_data;
+
+ /// \brief The associated driver (used by the driver manager to help
+ /// track state).
+ struct AdbcDriver* private_driver;
+};
+
+/// \defgroup adbc-statement-partition Partitioned Results
+/// Some backends may internally partition the results. These
+/// partitions are exposed to clients who may wish to integrate them
+/// with a threaded or distributed execution model, where partitions
+/// can be divided among threads or machines and fetched in parallel.
+///
+/// To use partitioning, execute the statement with
+/// AdbcStatementExecutePartitions to get the partition descriptors.
+/// Call AdbcConnectionReadPartition to turn the individual
+/// descriptors into ArrowArrayStream instances. This may be done on
+/// a different connection than the one the partition was created
+/// with, or even in a different process on another machine.
+///
+/// Drivers are not required to support partitioning.
+///
+/// @{
+
+/// \brief The partitions of a distributed/partitioned result set.
+struct AdbcPartitions {
+ /// \brief The number of partitions.
+ size_t num_partitions;
+
+ /// \brief The partitions of the result set, where each entry (up to
+ /// num_partitions entries) is an opaque identifier that can be
+ /// passed to AdbcConnectionReadPartition.
+ const uint8_t** partitions;
+
+ /// \brief The length of each corresponding entry in partitions.
+ const size_t* partition_lengths;
+
+ /// \brief Opaque implementation-defined state.
+ /// This field is NULLPTR iff the partitions are uninitialized/freed.
+ void* private_data;
+
+ /// \brief Release the contained partitions.
+ ///
+ /// Unlike other structures, this is an embedded callback to make it
+ /// easier for the driver manager and driver to cooperate.
+ void (*release)(struct AdbcPartitions* partitions);
+};
+
+/// @}
+
+/// @}
+
+/// \defgroup adbc-driver Driver Initialization
+///
+/// These functions are intended to help support integration between a
+/// driver and the driver manager.
+/// @{
+
+/// \brief An instance of an initialized database driver.
+///
+/// This provides a common interface for vendor-specific driver
+/// initialization routines. Drivers should populate this struct, and
+/// applications can call ADBC functions through this struct, without
+/// worrying about multiple definitions of the same symbol.
+struct ADBC_EXPORT AdbcDriver {
+ /// \brief Opaque driver-defined state.
+ /// This field is NULL if the driver is uninitialized/freed (but
+ /// it need not have a value even if the driver is initialized).
+ void* private_data;
+ /// \brief Opaque driver manager-defined state.
+ /// This field is NULL if the driver is uninitialized/freed (but
+ /// it need not have a value even if the driver is initialized).
+ void* private_manager;
+
+ /// \brief Release the driver and perform any cleanup.
+ ///
+ /// This is an embedded callback to make it easier for the driver
+ /// manager and driver to cooperate.
+ AdbcStatusCode (*release)(struct AdbcDriver* driver, struct AdbcError* error);
+
+ AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*);
+ AdbcStatusCode (*DatabaseNew)(struct AdbcDatabase*, struct AdbcError*);
+ AdbcStatusCode (*DatabaseSetOption)(struct AdbcDatabase*, const char*, const char*,
+ struct AdbcError*);
+ AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*);
+
+ AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t,
+ struct ArrowArrayStream*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*,
+ const char*, const char*, const char**,
+ const char*, struct ArrowArrayStream*,
+ struct AdbcError*);
+ AdbcStatusCode (*ConnectionGetTableSchema)(struct AdbcConnection*, const char*,
+ const char*, const char*,
+ struct ArrowSchema*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionGetTableTypes)(struct AdbcConnection*,
+ struct ArrowArrayStream*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcDatabase*,
+ struct AdbcError*);
+ AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*,
+ struct AdbcError*);
+ AdbcStatusCode (*ConnectionReadPartition)(struct AdbcConnection*, const uint8_t*,
+ size_t, struct ArrowArrayStream*,
+ struct AdbcError*);
+ AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionRollback)(struct AdbcConnection*, struct AdbcError*);
+
+ AdbcStatusCode (*StatementBind)(struct AdbcStatement*, struct ArrowArray*,
+ struct ArrowSchema*, struct AdbcError*);
+ AdbcStatusCode (*StatementBindStream)(struct AdbcStatement*, struct ArrowArrayStream*,
+ struct AdbcError*);
+ AdbcStatusCode (*StatementExecuteQuery)(struct AdbcStatement*, struct ArrowArrayStream*,
+ int64_t*, struct AdbcError*);
+ AdbcStatusCode (*StatementExecutePartitions)(struct AdbcStatement*, struct ArrowSchema*,
+ struct AdbcPartitions*, int64_t*,
+ struct AdbcError*);
+ AdbcStatusCode (*StatementGetParameterSchema)(struct AdbcStatement*,
+ struct ArrowSchema*, struct AdbcError*);
+ AdbcStatusCode (*StatementNew)(struct AdbcConnection*, struct AdbcStatement*,
+ struct AdbcError*);
+ AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*);
+ AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*);
+ AdbcStatusCode (*StatementSetOption)(struct AdbcStatement*, const char*, const char*,
+ struct AdbcError*);
+ AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*,
+ struct AdbcError*);
+ AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*,
+ size_t, struct AdbcError*);
+};
+
+/// @}
+
+/// \addtogroup adbc-database
+/// @{
+
+/// \brief Allocate a new (but uninitialized) database.
+ADBC_EXPORT
+AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error);
+
+/// \brief Set a char* option.
+///
+/// Options may be set before AdbcDatabaseInit. Some drivers may
+/// support setting options after initialization as well.
+///
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized
+ADBC_EXPORT
+AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error);
+
+/// \brief Finish setting options and initialize the database.
+///
+/// Some drivers may support setting options after initialization
+/// as well.
+ADBC_EXPORT
+AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error);
+
+/// \brief Destroy this database. No connections may exist.
+/// \param[in] database The database to release.
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+ADBC_EXPORT
+AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
+ struct AdbcError* error);
+
+/// @}
+
+/// \addtogroup adbc-connection
+/// @{
+
+/// \brief Allocate a new (but uninitialized) connection.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
+ struct AdbcError* error);
+
+/// \brief Set a char* option.
+///
+/// Options may be set before AdbcConnectionInit. Some drivers may
+/// support setting options after initialization as well.
+///
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
+ const char* value, struct AdbcError* error);
+
+/// \brief Finish setting options and initialize the connection.
+///
+/// Some drivers may support setting options after initialization
+/// as well.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database, struct AdbcError* error);
+
+/// \brief Destroy this connection.
+///
+/// \param[in] connection The connection to release.
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
+ struct AdbcError* error);
+
+/// \defgroup adbc-connection-metadata Metadata
+/// Functions for retrieving metadata about the database.
+///
+/// Generally, these functions return an ArrowArrayStream that can be
+/// consumed to get the metadata as Arrow data. The returned metadata
+/// has an expected schema given in the function docstring. Schema
+/// fields are nullable unless otherwise marked. While no
+/// AdbcStatement is used in these functions, the result set may count
+/// as an active statement to the driver for the purposes of
+/// concurrency management (e.g. if the driver has a limit on
+/// concurrent active statements and it must execute a SQL query
+/// internally in order to implement the metadata function).
+///
+/// Some functions accept "search pattern" arguments, which are
+/// strings that can contain the special character "%" to match zero
+/// or more characters, or "_" to match exactly one character. (See
+/// the documentation of DatabaseMetaData in JDBC or "Pattern Value
+/// Arguments" in the ODBC documentation.) Escaping is not currently
+/// supported.
+///
+/// @{
+
+/// \brief Get metadata about the database/driver.
+///
+/// The result is an Arrow dataset with the following schema:
+///
+/// Field Name | Field Type
+/// ----------------------------|------------------------
+/// info_name | uint32 not null
+/// info_value | INFO_SCHEMA
+///
+/// INFO_SCHEMA is a dense union with members:
+///
+/// Field Name (Type Code) | Field Type
+/// ----------------------------|------------------------
+/// string_value (0) | utf8
+/// bool_value (1) | bool
+/// int64_value (2) | int64
+/// int32_bitmask (3) | int32
+/// string_list (4) | list<utf8>
+/// int32_to_int32_list_map (5) | map<int32, list<int32>>
+///
+/// Each metadatum is identified by an integer code. The recognized
+/// codes are defined as constants. Codes [0, 10_000) are reserved
+/// for ADBC usage. Drivers/vendors will ignore requests for
+/// unrecognized codes (the row will be omitted from the result).
+///
+/// \param[in] connection The connection to query.
+/// \param[in] info_codes A list of metadata codes to fetch, or NULL
+/// to fetch all.
+/// \param[in] info_codes_length The length of the info_codes
+/// parameter. Ignored if info_codes is NULL.
+/// \param[out] out The result set.
+/// \param[out] error Error details, if an error occurs.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
+ uint32_t* info_codes, size_t info_codes_length,
+ struct ArrowArrayStream* out,
+ struct AdbcError* error);
+
+/// \brief Get a hierarchical view of all catalogs, database schemas,
+/// tables, and columns.
+///
+/// The result is an Arrow dataset with the following schema:
+///
+/// | Field Name | Field Type |
+/// |--------------------------|-------------------------|
+/// | catalog_name | utf8 |
+/// | catalog_db_schemas | list<DB_SCHEMA_SCHEMA> |
+///
+/// DB_SCHEMA_SCHEMA is a Struct with fields:
+///
+/// | Field Name | Field Type |
+/// |--------------------------|-------------------------|
+/// | db_schema_name | utf8 |
+/// | db_schema_tables | list<TABLE_SCHEMA> |
+///
+/// TABLE_SCHEMA is a Struct with fields:
+///
+/// | Field Name | Field Type |
+/// |--------------------------|-------------------------|
+/// | table_name | utf8 not null |
+/// | table_type | utf8 not null |
+/// | table_columns | list<COLUMN_SCHEMA> |
+/// | table_constraints | list<CONSTRAINT_SCHEMA> |
+///
+/// COLUMN_SCHEMA is a Struct with fields:
+///
+/// | Field Name | Field Type | Comments |
+/// |--------------------------|-------------------------|----------|
+/// | column_name | utf8 not null | |
+/// | ordinal_position | int32 | (1) |
+/// | remarks | utf8 | (2) |
+/// | xdbc_data_type | int16 | (3) |
+/// | xdbc_type_name | utf8 | (3) |
+/// | xdbc_column_size | int32 | (3) |
+/// | xdbc_decimal_digits | int16 | (3) |
+/// | xdbc_num_prec_radix | int16 | (3) |
+/// | xdbc_nullable | int16 | (3) |
+/// | xdbc_column_def | utf8 | (3) |
+/// | xdbc_sql_data_type | int16 | (3) |
+/// | xdbc_datetime_sub | int16 | (3) |
+/// | xdbc_char_octet_length | int32 | (3) |
+/// | xdbc_is_nullable | utf8 | (3) |
+/// | xdbc_scope_catalog | utf8 | (3) |
+/// | xdbc_scope_schema | utf8 | (3) |
+/// | xdbc_scope_table | utf8 | (3) |
+/// | xdbc_is_autoincrement | bool | (3) |
+/// | xdbc_is_generatedcolumn | bool | (3) |
+///
+/// 1. The column's ordinal position in the table (starting from 1).
+/// 2. Database-specific description of the column.
+/// 3. Optional value. Should be null if not supported by the driver.
+/// xdbc_ values are meant to provide JDBC/ODBC-compatible metadata
+/// in an agnostic manner.
+///
+/// CONSTRAINT_SCHEMA is a Struct with fields:
+///
+/// | Field Name | Field Type | Comments |
+/// |--------------------------|-------------------------|----------|
+/// | constraint_name | utf8 | |
+/// | constraint_type | utf8 not null | (1) |
+/// | constraint_column_names | list<utf8> not null | (2) |
+/// | constraint_column_usage | list<USAGE_SCHEMA> | (3) |
+///
+/// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'.
+/// 2. The columns on the current table that are constrained, in
+/// order.
+/// 3. For FOREIGN KEY only, the referenced table and columns.
+///
+/// USAGE_SCHEMA is a Struct with fields:
+///
+/// | Field Name | Field Type |
+/// |--------------------------|-------------------------|
+/// | fk_catalog | utf8 |
+/// | fk_db_schema | utf8 |
+/// | fk_table | utf8 not null |
+/// | fk_column_name | utf8 not null |
+///
+/// \param[in] connection The database connection.
+/// \param[in] depth The level of nesting to display. If 0, display
+/// all levels. If 1, display only catalogs (i.e. catalog_db_schemas
+/// will be null). If 2, display only catalogs and schemas
+/// (i.e. db_schema_tables will be null), and so on.
+/// \param[in] catalog Only show tables in the given catalog. If NULL,
+/// do not filter by catalog. If an empty string, only show tables
+/// without a catalog. May be a search pattern (see section
+/// documentation).
+/// \param[in] db_schema Only show tables in the given database schema. If
+/// NULL, do not filter by database schema. If an empty string, only show
+/// tables without a database schema. May be a search pattern (see section
+/// documentation).
+/// \param[in] table_name Only show tables with the given name. If NULL, do not
+/// filter by name. May be a search pattern (see section documentation).
+/// \param[in] table_type Only show tables matching one of the given table
+/// types. If NULL, show tables of any type. Valid table types can be fetched
+/// from GetTableTypes. Terminate the list with a NULL entry.
+/// \param[in] column_name Only show columns with the given name. If
+/// NULL, do not filter by name. May be a search pattern (see
+/// section documentation).
+/// \param[out] out The result set.
+/// \param[out] error Error details, if an error occurs.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth,
+ const char* catalog, const char* db_schema,
+ const char* table_name, const char** table_type,
+ const char* column_name,
+ struct ArrowArrayStream* out,
+ struct AdbcError* error);
+
+/// \brief Get the Arrow schema of a table.
+///
+/// \param[in] connection The database connection.
+/// \param[in] catalog The catalog (or nullptr if not applicable).
+/// \param[in] db_schema The database schema (or nullptr if not applicable).
+/// \param[in] table_name The table name.
+/// \param[out] schema The table schema.
+/// \param[out] error Error details, if an error occurs.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
+ const char* catalog, const char* db_schema,
+ const char* table_name,
+ struct ArrowSchema* schema,
+ struct AdbcError* error);
+
+/// \brief Get a list of table types in the database.
+///
+/// The result is an Arrow dataset with the following schema:
+///
+/// Field Name | Field Type
+/// ---------------|--------------
+/// table_type | utf8 not null
+///
+/// \param[in] connection The database connection.
+/// \param[out] out The result set.
+/// \param[out] error Error details, if an error occurs.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
+ struct ArrowArrayStream* out,
+ struct AdbcError* error);
+
+/// @}
+
+/// \defgroup adbc-connection-partition Partitioned Results
+/// Some databases may internally partition the results. These
+/// partitions are exposed to clients who may wish to integrate them
+/// with a threaded or distributed execution model, where partitions
+/// can be divided among threads or machines for processing.
+///
+/// Drivers are not required to support partitioning.
+///
+/// Partitions are not ordered. If the result set is sorted,
+/// implementations should return a single partition.
+///
+/// @{
+
+/// \brief Construct a statement for a partition of a query. The
+/// results can then be read independently.
+///
+/// A partition can be retrieved from AdbcPartitions.
+///
+/// \param[in] connection The connection to use. This does not have
+/// to be the same connection that the partition was created on.
+/// \param[in] serialized_partition The partition descriptor.
+/// \param[in] serialized_length The partition descriptor length.
+/// \param[out] out The result set.
+/// \param[out] error Error details, if an error occurs.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
+ const uint8_t* serialized_partition,
+ size_t serialized_length,
+ struct ArrowArrayStream* out,
+ struct AdbcError* error);
+
+/// @}
+
+/// \defgroup adbc-connection-transaction Transaction Semantics
+///
+/// Connections start out in auto-commit mode by default (if
+/// applicable for the given vendor). Use AdbcConnectionSetOption and
+/// ADBC_CONNECTION_OPTION_AUTOCOMMIT to change this.
+///
+/// @{
+
+/// \brief Commit any pending transactions. Only used if autocommit is
+/// disabled.
+///
+/// Behavior is undefined if this is mixed with SQL transaction
+/// statements.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,
+ struct AdbcError* error);
+
+/// \brief Roll back any pending transactions. Only used if autocommit
+/// is disabled.
+///
+/// Behavior is undefined if this is mixed with SQL transaction
+/// statements.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection,
+ struct AdbcError* error);
+
+/// @}
+
+/// @}
+
+/// \addtogroup adbc-statement
+/// @{
+
+/// \brief Create a new statement for a given connection.
+///
+/// Set options on the statement, then call AdbcStatementExecuteQuery
+/// or AdbcStatementPrepare.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
+ struct AdbcStatement* statement, struct AdbcError* error);
+
+/// \brief Destroy a statement.
+/// \param[in] statement The statement to release.
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
+ struct AdbcError* error);
+
+/// \brief Execute a statement and get the results.
+///
+/// This invalidates any prior result sets.
+///
+/// \param[in] statement The statement to execute.
+/// \param[out] out The results. Pass NULL if the client does not
+/// expect a result set.
+/// \param[out] rows_affected The number of rows affected if known,
+/// else -1. Pass NULL if the client does not want this information.
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
+ struct ArrowArrayStream* out,
+ int64_t* rows_affected, struct AdbcError* error);
+
+/// \brief Turn this statement into a prepared statement to be
+/// executed multiple times.
+///
+/// This invalidates any prior result sets.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
+ struct AdbcError* error);
+
+/// \defgroup adbc-statement-sql SQL Semantics
+/// Functions for executing SQL queries, or querying SQL-related
+/// metadata. Drivers are not required to support both SQL and
+/// Substrait semantics. If they do, it may be via converting
+/// between representations internally.
+/// @{
+
+/// \brief Set the SQL query to execute.
+///
+/// The query can then be executed with AdbcStatementExecuteQuery. For
+/// queries expected to be executed repeatedly, AdbcStatementPrepare
+/// the statement first.
+///
+/// \param[in] statement The statement.
+/// \param[in] query The query to execute.
+/// \param[out] error Error details, if an error occurs.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
+ const char* query, struct AdbcError* error);
+
+/// @}
+
+/// \defgroup adbc-statement-substrait Substrait Semantics
+/// Functions for executing Substrait plans, or querying
+/// Substrait-related metadata. Drivers are not required to support
+/// both SQL and Substrait semantics. If they do, it may be via
+/// converting between representations internally.
+/// @{
+
+/// \brief Set the Substrait plan to execute.
+///
+/// The plan can then be executed with AdbcStatementExecuteQuery. For
+/// queries expected to be executed repeatedly, AdbcStatementPrepare
+/// the statement first.
+///
+/// \param[in] statement The statement.
+/// \param[in] plan The serialized substrait.Plan to execute.
+/// \param[in] length The length of the serialized plan.
+/// \param[out] error Error details, if an error occurs.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement,
+ const uint8_t* plan, size_t length,
+ struct AdbcError* error);
+
+/// @}
+
+/// \brief Bind Arrow data. This can be used for bulk inserts or
+/// prepared statements.
+///
+/// \param[in] statement The statement to bind to.
+/// \param[in] values The values to bind. The driver will call the
+/// release callback itself, although it may not do this until the
+/// statement is released.
+/// \param[in] schema The schema of the values to bind.
+/// \param[out] error An optional location to return an error message
+/// if necessary.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
+ struct ArrowArray* values, struct ArrowSchema* schema,
+ struct AdbcError* error);
+
+/// \brief Bind Arrow data. This can be used for bulk inserts or
+/// prepared statements.
+/// \param[in] statement The statement to bind to.
+/// \param[in] stream The values to bind. The driver will call the
+/// release callback itself, although it may not do this until the
+/// statement is released.
+/// \param[out] error An optional location to return an error message
+/// if necessary.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error);
+
+/// \brief Get the schema for bound parameters.
+///
+/// This retrieves an Arrow schema describing the number, names, and
+/// types of the parameters in a parameterized statement. The fields
+/// of the schema should be in order of the ordinal position of the
+/// parameters; named parameters should appear only once.
+///
+/// If the parameter does not have a name, or the name cannot be
+/// determined, the name of the corresponding field in the schema will
+/// be an empty string. If the type cannot be determined, the type of
+/// the corresponding field will be NA (NullType).
+///
+/// This should be called after AdbcStatementPrepare.
+///
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the schema cannot be determined.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcError* error);
+
+/// \brief Set a string option on a statement.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
+ const char* value, struct AdbcError* error);
+
+/// \addtogroup adbc-statement-partition
+/// @{
+
+/// \brief Execute a statement and get the results as a partitioned
+/// result set.
+///
+/// \param[in] statement The statement to execute.
+/// \param[out] schema The schema of the result set.
+/// \param[out] partitions The result partitions.
+/// \param[out] rows_affected The number of rows affected if known,
+/// else -1. Pass NULL if the client does not want this information.
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support
+/// partitioned results
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcPartitions* partitions,
+ int64_t* rows_affected,
+ struct AdbcError* error);
+
+/// @}
+
+/// @}
+
+/// \addtogroup adbc-driver
+/// @{
+
+/// \brief Common entry point for drivers via the driver manager
+/// (which uses dlopen(3)/LoadLibrary). The driver manager is told
+/// to load a library and call a function of this type to load the
+/// driver.
+///
+/// Although drivers may choose any name for this function, the
+/// recommended name is "AdbcDriverInit".
+///
+/// \param[in] version The ADBC revision to attempt to initialize (see
+/// ADBC_VERSION_1_0_0).
+/// \param[out] driver The table of function pointers to
+/// initialize. Should be a pointer to the appropriate struct for
+/// the given version (see the documentation for the version).
+/// \param[out] error An optional location to return an error message
+/// if necessary.
+/// \return ADBC_STATUS_OK if the driver was initialized, or
+/// ADBC_STATUS_NOT_IMPLEMENTED if the version is not supported. In
+/// that case, clients may retry with a different version.
+typedef AdbcStatusCode (*AdbcDriverInitFunc)(int version, void* driver,
+ struct AdbcError* error);
+
+/// @}
+
+#endif // ADBC
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt
index 628b02b9d28..bedbe440996 100644
--- a/cpp/src/arrow/flight/sql/CMakeLists.txt
+++ b/cpp/src/arrow/flight/sql/CMakeLists.txt
@@ -37,6 +37,8 @@ set_source_files_properties(${FLIGHT_SQL_GENERATED_PROTO_FILES} PROPERTIES GENER
add_custom_target(flight_sql_protobuf_gen ALL DEPENDS ${FLIGHT_SQL_GENERATED_PROTO_FILES})
set(ARROW_FLIGHT_SQL_SRCS
+ adbc_driver.cc
+ adbc_driver_internal.cc
server.cc
sql_info_internal.cc
column_metadata.cc
@@ -94,7 +96,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
example/sqlite_server.cc
example/sqlite_tables_schema_batch_reader.cc)
- set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc)
+ set(ARROW_FLIGHT_SQL_TEST_SRCS adbc_driver_test.cc server_test.cc)
set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES})
set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc)
@@ -124,8 +126,15 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
STATIC_LINK_LIBS
${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}
${ARROW_FLIGHT_SQL_TEST_LIBS}
+ nanoarrow
+ AdbcValidation::adbc_validation
+ # Needs to come twice since the validation library
+ # also uses ADBC symbols, which get provided by
+ # libarrow_flight_sql here.
+ ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}
EXTRA_INCLUDES
- "${CMAKE_CURRENT_BINARY_DIR}/../"
+ # adbc_validation.h needs adbc.h
+ "${CMAKE_SOURCE_DIR}/../format/"
LABELS
"arrow_flight_sql")
diff --git a/cpp/src/arrow/flight/sql/adbc_driver.cc b/cpp/src/arrow/flight/sql/adbc_driver.cc
new file mode 100644
index 00000000000..96b93f31827
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/adbc_driver.cc
@@ -0,0 +1,2062 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_union.h"
+#include "arrow/c/bridge.h"
+#include "arrow/config.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/sql/adbc_driver_internal.h"
+#include "arrow/flight/sql/client.h"
+#include "arrow/flight/sql/server.h"
+#include "arrow/flight/sql/types.h"
+#include "arrow/flight/sql/visibility.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/type_fwd.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/config.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/value_parsing.h"
+
+#ifdef ARROW_COMPUTE
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/exec.h"
+#endif
+
+namespace arrow::flight::sql {
+
+using arrow::internal::checked_cast;
+
+namespace {
+/// Default (Arrow type, SQL type name) pairs used to seed
+/// FlightSqlQuirks::ingest_type_mapping.
+static const std::vector<std::pair<Type::type, std::string>> kDefaultTypeMapping = {
+    {Type::BINARY, "BLOB"},
+    {Type::BOOL, "BOOLEAN"},
+    {Type::DATE32, "DATE"},
+    {Type::DATE64, "DATE"},
+    {Type::DECIMAL128, "NUMERIC"},
+    {Type::DECIMAL256, "NUMERIC"},
+    {Type::DOUBLE, "DOUBLE PRECISION"},
+    {Type::FLOAT, "REAL"},
+    {Type::INT16, "SMALLINT"},
+    {Type::INT32, "INT"},
+    {Type::INT64, "BIGINT"},
+    {Type::LARGE_BINARY, "BLOB"},
+    {Type::LARGE_STRING, "TEXT"},
+    {Type::STRING, "TEXT"},
+    {Type::TIME32, "TIME"},
+    {Type::TIME64, "TIME"},
+    {Type::TIMESTAMP, "TIMESTAMP"},
+};
+
+/// \brief Client-side configuration to help paper over SQL dialect differences
+///
+/// This is needed when we have to generate SQL to implement ADBC
+/// features that have no direct Flight SQL equivalent. In
+/// particular, it's needed for bulk ingestion to a target table.
+struct FlightSqlQuirks {
+  /// A mapping from Arrow type to SQL type string
+  std::unordered_map<Type::type, std::string> ingest_type_mapping;
+
+  FlightSqlQuirks()
+      : ingest_type_mapping(kDefaultTypeMapping.begin(), kDefaultTypeMapping.end()) {}
+
+  /// \brief Override the SQL type name used for one Arrow type.
+  ///
+  /// \param[in] type_name Lower-case Arrow type name, e.g. "int64" or "large_string".
+  /// \param[in] value The SQL type name to use for that type.
+  /// \return true if the mapping was updated; false if type_name is not one
+  ///   of the supported names or value is NULL (the mapping is left untouched).
+  bool UpdateTypeMapping(std::string_view type_name, const char* value) {
+    // Accepted type names and the Arrow type each one selects.
+    static constexpr std::pair<std::string_view, Type::type> kTypeNames[] = {
+        {"binary", Type::BINARY},
+        {"bool", Type::BOOL},
+        {"date32", Type::DATE32},
+        {"date64", Type::DATE64},
+        {"decimal128", Type::DECIMAL128},
+        {"decimal256", Type::DECIMAL256},
+        {"double", Type::DOUBLE},
+        {"float", Type::FLOAT},
+        {"int16", Type::INT16},
+        {"int32", Type::INT32},
+        {"int64", Type::INT64},
+        {"large_binary", Type::LARGE_BINARY},
+        {"large_string", Type::LARGE_STRING},
+        {"string", Type::STRING},
+        {"time32", Type::TIME32},
+        {"time64", Type::TIME64},
+        {"timestamp", Type::TIMESTAMP},
+    };
+
+    // Assigning a null const char* to std::string is undefined behavior.
+    if (value == nullptr) return false;
+    for (const auto& [name, type_id] : kTypeNames) {
+      if (name == type_name) {
+        ingest_type_mapping[type_id] = value;
+        return true;
+      }
+    }
+    return false;
+  }
+};
+
+/// Config options that map to FlightClientOptions
+constexpr std::string_view kClientOptionTlsRootCerts =
+ "arrow.flight.sql.client_option.tls_root_certs";
+constexpr std::string_view kClientOptionOverrideHostname =
+ "arrow.flight.sql.client_option.override_hostname";
+constexpr std::string_view kClientOptionCertChain =
+ "arrow.flight.sql.client_option.cert_chain";
+constexpr std::string_view kClientOptionPrivateKey =
+ "arrow.flight.sql.client_option.private_key";
+constexpr std::string_view kClientOptionGenericIntOption =
+ "arrow.flight.sql.client_option.generic_int_option.";
+constexpr std::string_view kClientOptionGenericStringOption =
+ "arrow.flight.sql.client_option.generic_string_option.";
+constexpr std::string_view kClientOptionDisableServerVerification =
+ "arrow.flight.sql.client_option.disable_server_verification";
+/// Config option to enable JDBC driver-like authorization
+constexpr std::string_view kAuthorizationHeaderKey =
+ "arrow.flight.sql.authorization_header";
+/// Prefix of the config options used to override the type mapping in FlightSqlQuirks
+constexpr std::string_view kIngestTypePrefix = "arrow.flight.sql.quirks.ingest_type.";
+/// Explicitly specify the Substrait version for Flight SQL (although
+/// Substrait will eventually embed this into the plan itself)
+constexpr std::string_view kStatementSubstraitVersionKey =
+ "arrow.flight.sql.substrait.version";
+/// Prefix of the options used to attach arbitrary key-value headers via Flight
+constexpr std::string_view kCallHeaderPrefix = "arrow.flight.sql.rpc.call_header.";
+/// A timeout (in seconds) for any DoGet requests
+constexpr std::string_view kConnectionTimeoutFetchKey =
+ "arrow.flight.sql.rpc.timeout_seconds.fetch";
+/// A timeout (in seconds) for any GetFlightInfo requests
+constexpr std::string_view kConnectionTimeoutQueryKey =
+ "arrow.flight.sql.rpc.timeout_seconds.query";
+/// A timeout (in seconds) for any DoPut requests, or miscellaneous DoAction requests
+constexpr std::string_view kConnectionTimeoutUpdateKey =
+ "arrow.flight.sql.rpc.timeout_seconds.update";
+constexpr std::string_view kConnectionOptionAutocommit =
+ ADBC_CONNECTION_OPTION_AUTOCOMMIT;
+constexpr std::string_view kIngestOptionMode = ADBC_INGEST_OPTION_MODE;
+constexpr std::string_view kIngestOptionModeAppend = ADBC_INGEST_OPTION_MODE_APPEND;
+constexpr std::string_view kIngestOptionModeCreate = ADBC_INGEST_OPTION_MODE_CREATE;
+constexpr std::string_view kIngestOptionTargetTable = ADBC_INGEST_OPTION_TARGET_TABLE;
+constexpr std::string_view kOptionValueEnabled = ADBC_OPTION_VALUE_ENABLED;
+constexpr std::string_view kOptionValueDisabled = ADBC_OPTION_VALUE_DISABLED;
+
+enum class CallContext {
+ kFetch,  // presumably DoGet calls (cf. kConnectionTimeoutFetchKey); TODO confirm
+ kQuery,  // presumably GetFlightInfo calls (cf. kConnectionTimeoutQueryKey); TODO confirm
+ kUpdate,  // presumably DoPut/DoAction (cf. kConnectionTimeoutUpdateKey); TODO confirm
+};
+
+static const char kAuthorizationHeader[] = "authorization";  // call header name
+
+/// Implement auth in the same way as the Flight SQL JDBC driver
+/// 1. Client ---[authorization: Basic XXX ]--> Server
+/// 2. Client <--[authorization: Bearer YYY]--- Server
+/// 3. Client ---[authorization: Bearer YYY]--> Server
+class ClientAuthorizationMiddlewareFactory : public ClientMiddlewareFactory {
+ public:
+  explicit ClientAuthorizationMiddlewareFactory(std::string initial_header)
+      : token_(std::move(initial_header)) {}
+
+  /// Create the per-call middleware, seeded with a copy of the current token.
+  void StartCall(const CallInfo& info,
+                 std::unique_ptr<ClientMiddleware>* middleware) override {
+    // TODO: maybe we want a shared_ptr<std::string> and an atomic reference?
+    std::lock_guard<std::mutex> guard(token_mutex_);
+    // NOTE(review): if middleware can outlive the factory, Impl's raw factory_ dangles
+    *middleware = std::make_unique<Impl>(this, token_);
+  }
+
+  void UpdateToken(std::string_view token) {
+    std::lock_guard<std::mutex> guard(token_mutex_);
+    token_ = std::string(token);
+  }
+
+ private:
+  std::mutex token_mutex_;
+  std::string token_;
+
+  class Impl : public ClientMiddleware {
+   public:
+    explicit Impl(ClientAuthorizationMiddlewareFactory* factory, std::string token)
+        : factory_(factory), token_(std::move(token)) {}
+    void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+      outgoing_headers->AddHeader(kAuthorizationHeader, token_);
+    }
+    void ReceivedHeaders(const CallHeaders& incoming_headers) override {
+      // Assume server isn't going to send multiple authorization headers
+      auto it = incoming_headers.find(std::string_view(kAuthorizationHeader));
+      if (it == incoming_headers.end() || it->second == token_) return;
+      factory_->UpdateToken(it->second);
+    }
+    void CallCompleted(const Status&) override {}
+
+   private:
+    ClientAuthorizationMiddlewareFactory* factory_;
+    std::string token_;
+  };
+};
+
+/// \brief AdbcDatabase implementation
+///
+/// Collects options until Init(), then hands out Flight SQL clients for the
+/// configured location and keeps count of the connections still open.
+class FlightSqlDatabaseImpl {
+ public:
+  /// \brief Open a new Flight client against the configured location.
+  ///
+  /// \param[out] client receives the new client on success
+  /// \param[out] error populated on failure
+  /// \return ADBC_STATUS_INVALID_STATE if Init() has not succeeded yet
+  AdbcStatusCode Connect(std::unique_ptr<FlightSqlClient>* client,
+                         struct AdbcError* error) {
+    std::lock_guard<std::mutex> guard(mutex_);
+    if (!location_.has_value()) {
+      SetError(error, "Cannot create connection from uninitialized database");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+
+    std::unique_ptr<FlightClient> flight_client;
+    ADBC_ARROW_RETURN_NOT_OK(
+        error, FlightClient::Connect(*location_, client_options_).Value(&flight_client));
+    *client = std::make_unique<FlightSqlClient>(std::move(flight_client));
+    ++connection_count_;
+    return ADBC_STATUS_OK;
+  }
+
+  const std::shared_ptr<FlightSqlQuirks>& quirks() const { return quirks_; }
+
+  /// \brief Build call options carrying the database-level call headers.
+  ///
+  /// `context` is currently unused here; timeouts are applied per connection.
+  FlightCallOptions MakeCallOptions(CallContext context) const {
+    FlightCallOptions options;
+    for (const auto& header : call_headers_) {
+      options.headers.emplace_back(header.first, header.second);
+    }
+    return options;
+  }
+
+  /// \brief Finish configuration: parse the required 'uri' option.
+  ///
+  /// \return ADBC_STATUS_INVALID_STATE if already initialized,
+  ///   ADBC_STATUS_INVALID_ARGUMENT if 'uri' was never set
+  AdbcStatusCode Init(struct AdbcError* error) {
+    std::lock_guard<std::mutex> guard(mutex_);
+
+    if (location_.has_value()) {
+      SetError(error, "Database already initialized");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+
+    auto it = options_.find("uri");
+    if (it == options_.end()) {
+      SetError(error, "Must provide 'uri' option");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    Location location;
+    ADBC_ARROW_RETURN_NOT_OK(error, Location::Parse(it->second).Value(&location));
+    location_ = std::move(location);
+
+    // The generic options are only needed to find the URI
+    options_.clear();
+    return ADBC_STATUS_OK;
+  }
+
+  /// \brief Set a database-level option.
+  ///
+  /// Driver-specific keys (type mapping, call headers, TLS/client options,
+  /// authorization header) are applied immediately. Any other key is stored
+  /// for Init() and is rejected once the database has been initialized.
+  ///
+  /// \param[in] key option name, must not be null
+  /// \param[in] value option value; null is treated as the empty string, except
+  ///   for call headers where it removes the header
+  ///
+  /// NOTE(review): unlike Connect()/Init(), this method does not take mutex_
+  /// although it mutates the same members - confirm callers serialize access.
+  AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) {
+    if (key == nullptr) {
+      SetError(error, "Key must not be null");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    std::string_view key_view(key);
+    std::string_view val_view = value ? value : "";
+    if (key_view.rfind(kIngestTypePrefix, 0) == 0) {
+      // Changing the mapping from Arrow type <-> SQL type name, for when we
+      // do a bulk ingest and have to generate a CREATE TABLE statement
+      const std::string_view type_name = key_view.substr(kIngestTypePrefix.size());
+      // Never hand a null pointer to the quirks object
+      if (!quirks_->UpdateTypeMapping(type_name, value ? value : "")) {
+        SetError(error, "Unknown option value ", key_view, "=", val_view, ": type name ",
+                 type_name, " is not recognized");
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+      return ADBC_STATUS_OK;
+    } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) {
+      // Add a custom header to all outgoing calls (can also be set at
+      // connection, statement level)
+      std::string header(key_view.substr(kCallHeaderPrefix.size()));
+      if (value == nullptr) {
+        call_headers_.erase(header);
+      } else {
+        // insert() would silently keep a previously set value
+        call_headers_.insert_or_assign(std::move(header), std::string(val_view));
+      }
+      return ADBC_STATUS_OK;
+    } else if (key_view == kClientOptionTlsRootCerts) {
+      client_options_.tls_root_certs = val_view;
+      return ADBC_STATUS_OK;
+    } else if (key_view == kClientOptionOverrideHostname) {
+      client_options_.override_hostname = val_view;
+      return ADBC_STATUS_OK;
+    } else if (key_view == kClientOptionCertChain) {
+      client_options_.cert_chain = val_view;
+      return ADBC_STATUS_OK;
+    } else if (key_view == kClientOptionPrivateKey) {
+      client_options_.private_key = val_view;
+      return ADBC_STATUS_OK;
+    } else if (key_view == kClientOptionDisableServerVerification) {
+      if (val_view == kOptionValueEnabled) {
+        client_options_.disable_server_verification = true;
+        return ADBC_STATUS_OK;
+      } else if (val_view == kOptionValueDisabled) {
+        client_options_.disable_server_verification = false;
+        return ADBC_STATUS_OK;
+      }
+      SetError(error,
+               "Invalid boolean value for client option disable_server_verification: ",
+               val_view);
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    } else if (key_view.rfind(kClientOptionGenericIntOption, 0) == 0) {
+      const std::string_view option_name =
+          key_view.substr(kClientOptionGenericIntOption.size());
+      int32_t option_value = 0;
+      if (!arrow::internal::StringConverter<Int32Type>().Convert(
+              Int32Type(), val_view.data(), val_view.size(), &option_value)) {
+        SetError(error,
+                 "Invalid integer value for client option generic_options: ", val_view);
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+      client_options_.generic_options.emplace_back(option_name, option_value);
+      return ADBC_STATUS_OK;
+    } else if (key_view.rfind(kClientOptionGenericStringOption, 0) == 0) {
+      const std::string_view option_name =
+          key_view.substr(kClientOptionGenericStringOption.size());
+      client_options_.generic_options.emplace_back(option_name, std::string(val_view));
+      return ADBC_STATUS_OK;
+    } else if (key_view == kAuthorizationHeaderKey) {
+      // NOTE(review): setting this option more than once registers one
+      // middleware per call to SetOption - confirm whether the previous one
+      // should be replaced instead.
+      client_options_.middleware.push_back(
+          std::make_shared<ClientAuthorizationMiddlewareFactory>(std::string(val_view)));
+      return ADBC_STATUS_OK;
+    }
+
+    // Everything else is only consumed by Init()
+    if (location_.has_value()) {
+      SetError(error, "Database already initialized");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+    options_[std::string(key_view)] = std::string(val_view);
+    return ADBC_STATUS_OK;
+  }
+
+  /// \brief Record that one connection obtained via Connect() was closed.
+  AdbcStatusCode Disconnect(struct AdbcError* error) {
+    std::lock_guard<std::mutex> guard(mutex_);
+    // Check before decrementing so the counter never goes negative
+    if (connection_count_ <= 0) {
+      SetError(error, "Connection count underflow");
+      return ADBC_STATUS_INTERNAL;
+    }
+    --connection_count_;
+    return ADBC_STATUS_OK;
+  }
+
+  /// \brief Check that the database can be released (no open connections).
+  AdbcStatusCode Release(struct AdbcError* error) {
+    std::lock_guard<std::mutex> guard(mutex_);
+
+    if (connection_count_ > 0) {
+      SetError(error, "Cannot release database with ", connection_count_,
+               " open connections");
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    return ADBC_STATUS_OK;
+  }
+
+ private:
+  // Set by Init(); its presence marks the database as initialized
+  std::optional<Location> location_;
+  FlightClientOptions client_options_ = DefaultClientOptions();
+  std::shared_ptr<FlightSqlQuirks> quirks_ = std::make_shared<FlightSqlQuirks>();
+  // Headers attached to every call made through this database
+  std::unordered_map<std::string, std::string> call_headers_;
+  // Options not handled directly by SetOption(); consumed by Init()
+  std::unordered_map<std::string, std::string> options_;
+  std::mutex mutex_;
+  // Number of clients handed out by Connect() and not yet Disconnect()ed
+  int connection_count_ = 0;
+};
+
+/// \brief A RecordBatchReader that reads the endpoints of a FlightInfo
+///
+/// Endpoints are read one after another. An endpoint without locations is
+/// fetched through the given client; otherwise each listed location is tried
+/// in turn until one can be connected to and queried.
+class FlightInfoReader : public RecordBatchReader {
+ public:
+  /// \param[in] client client used for endpoints that list no location; not owned
+  /// \param[in] call_options options applied to every DoGet
+  /// \param[in] info the FlightInfo whose endpoints will be read
+  explicit FlightInfoReader(FlightSqlClient* client, FlightCallOptions call_options,
+                            std::unique_ptr<FlightInfo> info)
+      : client_(client),
+        call_options_(std::move(call_options)),
+        info_(std::move(info)),
+        next_endpoint_(0) {}
+
+  std::shared_ptr<Schema> schema() const override { return schema_; }
+
+  /// \brief Read the next batch, moving on to the next endpoint as needed.
+  ///
+  /// `*batch` is set to null once all endpoints are exhausted.
+  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+    FlightStreamChunk chunk;
+    while (current_stream_ && !chunk.data) {
+      ARROW_ASSIGN_OR_RAISE(chunk, current_stream_->Next());
+      if (chunk.data) {
+        *batch = chunk.data;
+        break;
+      }
+      // Neither data nor metadata: this stream is finished. A metadata-only
+      // chunk is skipped by looping again.
+      if (!chunk.data && !chunk.app_metadata) {
+        RETURN_NOT_OK(NextStream());
+      }
+    }
+    if (!current_stream_) *batch = nullptr;
+    return Status::OK();
+  }
+
+  /// \brief Cancel and drop the stream currently being read, if any.
+  Status Close() override {
+    if (current_stream_) {
+      current_stream_->Cancel();
+      current_stream_.reset();
+    }
+    return Status::OK();
+  }
+
+  /// \brief Open the first endpoint and determine the schema.
+  AdbcStatusCode Init(struct AdbcError* error) {
+    ADBC_ARROW_RETURN_NOT_OK(error, NextStream());
+    if (!schema_) {
+      // Empty result set - fall back on schema in FlightInfo
+      ipc::DictionaryMemo memo;
+      ADBC_ARROW_RETURN_NOT_OK(error, info_->GetSchema(&memo).Value(&schema_));
+    }
+    return ADBC_STATUS_OK;
+  }
+
+  /// \brief Export to an ArrowArrayStream
+  static AdbcStatusCode Export(FlightSqlClient* client, FlightCallOptions call_options,
+                               std::unique_ptr<FlightInfo> info,
+                               struct ArrowArrayStream* stream, struct AdbcError* error) {
+    auto reader = std::make_shared<FlightInfoReader>(client, std::move(call_options),
+                                                     std::move(info));
+    ADBC_RETURN_NOT_OK(reader->Init(error));
+    ADBC_ARROW_RETURN_NOT_OK(error, ExportRecordBatchReader(std::move(reader), stream));
+    return ADBC_STATUS_OK;
+  }
+
+ private:
+  /// Open the stream for the next endpoint, or clear current_stream_ when
+  /// there are no endpoints left.
+  Status NextStream() {
+    if (next_endpoint_ >= info_->endpoints().size()) {
+      current_stream_ = nullptr;
+      return Status::OK();
+    }
+    const FlightEndpoint& endpoint = info_->endpoints()[next_endpoint_];
+
+    if (endpoint.locations.empty()) {
+      ARROW_ASSIGN_OR_RAISE(current_stream_,
+                            client_->DoGet(call_options_, endpoint.ticket));
+    } else {
+      // TODO(lidavidm): this should come from a connection pool
+      std::string failures;
+      // Drop the previous stream before its client may be replaced below
+      current_stream_ = nullptr;
+      for (const Location& location : endpoint.locations) {
+        auto status =
+            FlightClient::Connect(location, DefaultClientOptions()).Value(&data_client_);
+        if (status.ok()) {
+          status =
+              data_client_->DoGet(call_options_, endpoint.ticket).Value(&current_stream_);
+        }
+        if (!status.ok()) {
+          // Remember why this location failed and try the next one
+          if (!failures.empty()) {
+            failures += "; ";
+          }
+          failures += location.ToString();
+          failures += ": ";
+          failures += status.ToString();
+          data_client_.reset();
+          continue;
+        }
+        break;
+      }
+
+      if (!current_stream_) {
+        return Status::IOError("Failed to connect to all endpoints: ", failures);
+      }
+    }
+    next_endpoint_++;
+    if (!schema_) {
+      ARROW_ASSIGN_OR_RAISE(schema_, current_stream_->GetSchema());
+    }
+    return Status::OK();
+  }
+
+  FlightSqlClient* client_;
+  FlightCallOptions call_options_;
+  std::unique_ptr<FlightInfo> info_;
+  // Index of the endpoint that NextStream() will open next
+  size_t next_endpoint_;
+  std::shared_ptr<Schema> schema_;
+  std::unique_ptr<FlightStreamReader> current_stream_;
+  // TODO(lidavidm): use a common pool of cached clients with expiration
+  std::unique_ptr<FlightClient> data_client_;
+};
+
+class FlightSqlConnectionImpl {
+ public:
+ //----------------------------------------------------------
+ // Common Functions
+ //----------------------------------------------------------
+
+ // Non-owning pointer to the underlying client; null when client_ is unset
+ FlightSqlClient* client() const { return client_.get(); }
+ // Dereferences quirks_, which Init() copies from the database
+ const FlightSqlQuirks& quirks() const { return *quirks_; }
+ // The connection's current transaction handle (may be invalid)
+ const Transaction& transaction() const { return transaction_; }
+ /// \brief Build the options for a call of the given category.
+ ///
+ /// Starts from the database-level options, applies the timeout configured
+ /// for `context` (if any) and appends the connection-level call headers.
+ FlightCallOptions MakeCallOptions(CallContext context) const {
+   FlightCallOptions call_options = database_->MakeCallOptions(context);
+
+   const auto configured_timeout = timeout_seconds_.find(context);
+   if (configured_timeout != timeout_seconds_.end()) {
+     call_options.timeout = configured_timeout->second;
+   }
+
+   for (const auto& [header_name, header_value] : call_headers_) {
+     call_options.headers.emplace_back(header_name, header_value);
+   }
+   return call_options;
+ }
+
+ /// \brief Attach this connection to a database and open a client.
+ ///
+ /// \param[in] database an initialized AdbcDatabase whose private_data points
+ ///   at a shared_ptr to the database implementation
+ /// \param[out] error populated on failure
+ /// \return ADBC_STATUS_INVALID_STATE if the database is missing or
+ ///   uninitialized, or if this connection already has a client
+ AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error) {
+   if (!database || !database->private_data) {
+     SetError(error, "Cannot create connection from uninitialized database");
+     return ADBC_STATUS_INVALID_STATE;
+   } else if (client_) {
+     SetError(error, "Already initialized connection");
+     return ADBC_STATUS_INVALID_STATE;
+   }
+
+   // Copy the shared_ptr so this connection shares ownership of the database
+   // implementation
+   database_ = *reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(
+       database->private_data);
+   ADBC_RETURN_NOT_OK(database_->Connect(&client_, error));
+   quirks_ = database_->quirks();
+
+   // Ignore this if it fails - we'll just proceed
+   ARROW_UNUSED(QuerySqlInfo());
+   return ADBC_STATUS_OK;
+ }
+
+ /// \brief Close the client (if any) and tell the database about it.
+ ///
+ /// A failure to close the client is reported as ADBC_STATUS_IO, unless
+ /// notifying the database fails as well, in which case that status wins.
+ AdbcStatusCode Close(struct AdbcError* error) {
+   // Nothing was opened, so there is nothing to undo
+   if (!client_) return ADBC_STATUS_OK;
+
+   const auto close_status = client_->Close();
+   client_.reset();
+
+   AdbcStatusCode outcome = ADBC_STATUS_OK;
+   if (!close_status.ok()) {
+     SetError(error, close_status);
+     outcome = ADBC_STATUS_IO;
+   }
+
+   if (database_) {
+     ADBC_RETURN_NOT_OK(database_->Disconnect(error));
+   }
+   return outcome;
+ }
+
+ /// \brief Set a connection-level option.
+ ///
+ /// Recognized keys are the ADBC autocommit option, the three
+ /// "arrow.flight.sql.rpc.timeout_seconds.*" keys and any key starting with
+ /// "arrow.flight.sql.rpc.call_header.".
+ ///
+ /// \param[in] key option name, must not be null
+ /// \param[in] value option value. Null is treated as the empty string,
+ ///   except for call headers where it removes the header.
+ /// \param[out] error populated on failure
+ /// \return ADBC_STATUS_NOT_IMPLEMENTED for unknown keys,
+ ///   ADBC_STATUS_INVALID_ARGUMENT for malformed values
+ AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) {
+   if (key == nullptr) {
+     SetError(error, "Key must not be null");
+     return ADBC_STATUS_INVALID_ARGUMENT;
+   }
+
+   const std::string_view key_view(key);
+   const std::string_view val_view = value ? value : "";
+
+   // Parses val_view as a number of seconds. Works on the view rather than on
+   // `value` so that a null value is rejected as invalid instead of being
+   // dereferenced.
+   auto set_timeout_option = [&](CallContext context) -> AdbcStatusCode {
+     double timeout = 0.0;
+     if (!arrow::internal::StringToFloat(val_view.data(), val_view.size(),
+                                         /*decimal_point=*/'.', &timeout)) {
+       SetError(error, "Invalid timeout option value ", key_view, '=', val_view,
+                ": invalid floating point value");
+       return ADBC_STATUS_INVALID_ARGUMENT;
+     }
+     if (std::isnan(timeout) || std::isinf(timeout) || timeout < 0) {
+       SetError(error, "Invalid timeout option value ", key_view, '=', val_view,
+                ": timeout must be positive and finite");
+       return ADBC_STATUS_INVALID_ARGUMENT;
+     }
+
+     // Zero means "no timeout"
+     if (timeout == 0) {
+       timeout_seconds_.erase(context);
+     } else {
+       timeout_seconds_[context] = std::chrono::duration<double>(timeout);
+     }
+     return ADBC_STATUS_OK;
+   };
+
+   if (key_view == kConnectionOptionAutocommit) {
+     if (val_view == kOptionValueEnabled && !transaction_.is_valid()) {
+       // No-op - don't error even if the server didn't support transactions
+       return ADBC_STATUS_OK;
+     }
+     ADBC_RETURN_NOT_OK(CheckTransactionSupport(error));
+     FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate);
+     if (val_view == kOptionValueEnabled) {
+       // Leaving manual mode: commit whatever is pending
+       if (transaction_.is_valid()) {
+         ADBC_ARROW_RETURN_NOT_OK(error, client_->Commit(call_options, transaction_));
+         transaction_ = no_transaction();
+       }
+       return ADBC_STATUS_OK;
+     } else if (val_view == kOptionValueDisabled) {
+       // Entering manual mode: commit any open transaction, then start a new one
+       if (transaction_.is_valid()) {
+         ADBC_ARROW_RETURN_NOT_OK(error, client_->Commit(call_options, transaction_));
+       }
+       ADBC_ARROW_RETURN_NOT_OK(
+           error, client_->BeginTransaction(call_options).Value(&transaction_));
+       return ADBC_STATUS_OK;
+     }
+     SetError(error, "Invalid connection option value ", key_view, '=', val_view);
+     return ADBC_STATUS_INVALID_ARGUMENT;
+   } else if (key_view == kConnectionTimeoutFetchKey) {
+     return set_timeout_option(CallContext::kFetch);
+   } else if (key_view == kConnectionTimeoutQueryKey) {
+     return set_timeout_option(CallContext::kQuery);
+   } else if (key_view == kConnectionTimeoutUpdateKey) {
+     return set_timeout_option(CallContext::kUpdate);
+   } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) {
+     std::string header(key_view.substr(kCallHeaderPrefix.size()));
+     if (value == nullptr) {
+       call_headers_.erase(header);
+     } else {
+       // insert() would silently keep a previously set value
+       call_headers_.insert_or_assign(std::move(header), std::string(val_view));
+     }
+     return ADBC_STATUS_OK;
+   }
+   SetError(error, "Unknown connection option ", key_view, '=', val_view);
+   return ADBC_STATUS_NOT_IMPLEMENTED;
+ }
+
+ //----------------------------------------------------------
+ // Metadata
+ //----------------------------------------------------------
+
+ /// \brief Fetch driver and server metadata as an ArrowArrayStream.
+ ///
+ /// Vendor codes are forwarded to the server via GetSqlInfo; driver codes are
+ /// answered locally. When no codes are given, all supported codes are
+ /// returned.
+ ///
+ /// \param[in] info_codes requested ADBC info codes, may be null
+ /// \param[in] info_codes_length number of entries in info_codes
+ /// \param[out] stream receives a single batch with columns
+ ///   (info_name, info_value)
+ /// \param[out] error populated on failure
+ /// \return ADBC_STATUS_INVALID_ARGUMENT for an unsupported info code,
+ ///   ADBC_STATUS_INTERNAL if the server response has an unexpected shape
+ AdbcStatusCode GetInfo(uint32_t* info_codes, size_t info_codes_length,
+                        struct ArrowArrayStream* stream, struct AdbcError* error) {
+   static std::shared_ptr<Schema> kInfoSchema = arrow::schema({
+       arrow::field("info_name", arrow::uint32(), /*nullable=*/false),
+       arrow::field(
+           "info_value",
+           arrow::dense_union({
+               arrow::field("string_value", arrow::utf8()),
+               arrow::field("bool_value", arrow::boolean()),
+               arrow::field("int64_value", arrow::int64()),
+               arrow::field("int32_bitmask", arrow::int32()),
+               arrow::field("string_list", arrow::list(arrow::utf8())),
+               arrow::field("int32_to_int32_list_map",
+                            arrow::map(arrow::int32(), arrow::list(arrow::int32()))),
+           })),
+   });
+
+   // XXX(ARROW-17558): type should be uint32_t not int
+   std::vector<int> flight_sql_codes;
+   std::vector<uint32_t> codes;
+   if (info_codes && info_codes_length > 0) {
+     for (size_t i = 0; i < info_codes_length; i++) {
+       const uint32_t info_code = info_codes[i];
+       switch (info_code) {
+         case ADBC_INFO_VENDOR_NAME:
+         case ADBC_INFO_VENDOR_VERSION:
+         case ADBC_INFO_VENDOR_ARROW_VERSION:
+           // These codes are equivalent between the two
+           flight_sql_codes.push_back(static_cast<int>(info_code));
+           break;
+         case ADBC_INFO_DRIVER_NAME:
+         case ADBC_INFO_DRIVER_VERSION:
+         case ADBC_INFO_DRIVER_ARROW_VERSION:
+           codes.push_back(info_code);
+           break;
+         default:
+           SetError(error, "Unknown info code: ", info_code);
+           return ADBC_STATUS_INVALID_ARGUMENT;
+       }
+     }
+   } else {
+     flight_sql_codes = {
+         SqlInfoOptions::FLIGHT_SQL_SERVER_NAME,
+         SqlInfoOptions::FLIGHT_SQL_SERVER_VERSION,
+         SqlInfoOptions::FLIGHT_SQL_SERVER_ARROW_VERSION,
+     };
+     codes = {
+         ADBC_INFO_DRIVER_NAME,
+         ADBC_INFO_DRIVER_VERSION,
+         ADBC_INFO_DRIVER_ARROW_VERSION,
+     };
+   }
+
+   RecordBatchVector result;
+
+   UInt32Builder names;
+   std::unique_ptr<ArrayBuilder> values;
+   ADBC_ARROW_RETURN_NOT_OK(error,
+                            MakeBuilder(kInfoSchema->field(1)->type()).Value(&values));
+   auto* info_value = static_cast<DenseUnionBuilder*>(values.get());
+   auto* info_string = static_cast<StringBuilder*>(info_value->child_builder(0).get());
+   int64_t num_values = 0;
+
+   constexpr int8_t kStringCode = 0;
+
+   // NOTE: DenseUnionBuilder::Append records the child's *current* length as
+   // the offset of the new element, so it must be called before the value is
+   // appended to the child builder.
+
+   if (!flight_sql_codes.empty()) {
+     FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery);
+     std::unique_ptr<FlightInfo> info;
+     ADBC_ARROW_RETURN_NOT_OK(
+         error, client_->GetSqlInfo(call_options, flight_sql_codes).Value(&info));
+     FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch),
+                             std::move(info));
+     ADBC_RETURN_NOT_OK(reader.Init(error));
+
+     if (!reader.schema()->Equals(*SqlSchema::GetSqlInfoSchema())) {
+       SetError(error, "Server returned wrong schema, got: ", *reader.schema());
+       return ADBC_STATUS_INTERNAL;
+     }
+
+     while (true) {
+       std::shared_ptr<RecordBatch> batch;
+       ADBC_ARROW_RETURN_NOT_OK(error, reader.Next().Value(&batch));
+       if (!batch) break;
+
+       const auto& sql_codes = checked_cast<const UInt32Array&>(*batch->column(0));
+       const auto& sql_value = checked_cast<const DenseUnionArray&>(*batch->column(1));
+       const auto& sql_string =
+           checked_cast<const StringArray&>(*sql_value.field(kStringCode));
+       for (int64_t i = 0; i < batch->num_rows(); i++) {
+         // Shouldn't happen but oh well
+         if (!sql_codes.IsValid(i)) continue;
+
+         switch (sql_codes.Value(i)) {
+           case SqlInfoOptions::FLIGHT_SQL_SERVER_NAME:
+           case SqlInfoOptions::FLIGHT_SQL_SERVER_VERSION:
+           case SqlInfoOptions::FLIGHT_SQL_SERVER_ARROW_VERSION: {
+             // These should all be string values where the codes are
+             // equivalent between ADBC/Flight SQL
+             if (sql_value.type_code(i) != kStringCode) {
+               SetError(error, "Server returned wrong type for info value ",
+                        sql_codes.Value(i));
+               return ADBC_STATUS_INTERNAL;
+             }
+             ADBC_ARROW_RETURN_NOT_OK(error, names.Append(sql_codes.Value(i)));
+             ADBC_ARROW_RETURN_NOT_OK(error, info_value->Append(kStringCode));
+             ADBC_ARROW_RETURN_NOT_OK(
+                 error,
+                 info_string->Append(sql_string.GetString(sql_value.value_offset(i))));
+             num_values++;
+             break;
+           }
+           default:
+             // Ignore if the server returns something unknown
+             continue;
+         }
+       }
+     }
+   }
+
+   // Codes answered by the driver itself
+   for (const uint32_t code : codes) {
+     std::string driver_value;
+     switch (code) {
+       case ADBC_INFO_DRIVER_NAME:
+         driver_value = "Apache Arrow Flight SQL ADBC Driver";
+         break;
+       case ADBC_INFO_DRIVER_VERSION:
+       case ADBC_INFO_DRIVER_ARROW_VERSION:
+         driver_value = GetBuildInfo().version_string;
+         break;
+       default:
+         SetError(error, "Unknown info code: ", code);
+         return ADBC_STATUS_INVALID_ARGUMENT;
+     }
+     ADBC_ARROW_RETURN_NOT_OK(error, names.Append(code));
+     ADBC_ARROW_RETURN_NOT_OK(error, info_value->Append(kStringCode));
+     ADBC_ARROW_RETURN_NOT_OK(error, info_string->Append(driver_value));
+     num_values++;
+   }
+
+   ArrayVector cols(2);
+   ADBC_ARROW_RETURN_NOT_OK(error, names.Finish(&cols[0]));
+   ADBC_ARROW_RETURN_NOT_OK(error, values->Finish(&cols[1]));
+   result.push_back(RecordBatch::Make(kInfoSchema, num_values, std::move(cols)));
+
+   ADBC_ARROW_RETURN_NOT_OK(error,
+                            ExportRecordBatches(kInfoSchema, std::move(result), stream));
+   return ADBC_STATUS_OK;
+ }
+
+ /// \brief Build a nested listing of catalogs, schemas, tables and columns.
+ ///
+ /// The result is a single record batch following kCatalogSchema below. How
+ /// deep the nesting is populated depends on `depth`; levels below the
+ /// requested depth are emitted as null lists.
+ ///
+ /// \param[in] depth one of the ADBC_OBJECT_DEPTH_* values
+ /// \param[in] catalog if non-null, only the catalog with exactly this name is
+ ///   kept (compared client-side)
+ /// \param[in] db_schema if non-null, forwarded to the server as a filter
+ /// \param[in] table_name if non-null, forwarded to the server as a filter
+ /// \param[in] table_type if non-null, a null-terminated array of table types
+ ///   forwarded to the server as a filter
+ /// \param[in] column_name if non-empty, columns are filtered with the
+ ///   "match_like" compute function (requires ARROW_COMPUTE)
+ /// \param[out] stream receives the result
+ /// \param[out] error populated on failure
+ ///
+ /// NOTE(review): in this listing the template arguments of declarations such
+ /// as `std::vector table_type_filter` and the target types of
+ /// `checked_cast(...)` appear to be missing - presumably lost in extraction;
+ /// restore them before compiling.
+ AdbcStatusCode GetObjects(int depth, const char* catalog, const char* db_schema,
+ const char* table_name, const char** table_type,
+ const char* column_name, struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ // Result layout, innermost type first: column -> constraint usage ->
+ // constraint -> table -> db schema -> catalog
+ static std::shared_ptr kColumnSchema = arrow::struct_({
+ arrow::field("column_name", arrow::utf8(), /*nullable=*/false),
+ arrow::field("ordinal_position", arrow::int32()),
+ arrow::field("remarks", arrow::utf8()),
+ arrow::field("xdbc_data_type", arrow::int16()),
+ arrow::field("xdbc_type_name", arrow::utf8()),
+ arrow::field("xdbc_column_size", arrow::int32()),
+ arrow::field("xdbc_decimal_digits", arrow::int16()),
+ arrow::field("xdbc_num_prec_radix", arrow::int16()),
+ arrow::field("xdbc_nullable", arrow::int16()),
+ arrow::field("xdbc_column_def", arrow::utf8()),
+ arrow::field("xdbc_sql_data_type", arrow::int16()),
+ arrow::field("xdbc_datetime_sub", arrow::int16()),
+ arrow::field("xdbc_char_octet_length", arrow::int32()),
+ arrow::field("xdbc_is_nullable", arrow::utf8()),
+ arrow::field("xdbc_scope_catalog", arrow::utf8()),
+ arrow::field("xdbc_scope_schema", arrow::utf8()),
+ arrow::field("xdbc_scope_table", arrow::utf8()),
+ arrow::field("xdbc_is_autoincrement", arrow::boolean()),
+ arrow::field("xdbc_is_generatedcolumn", arrow::boolean()),
+ });
+ static std::shared_ptr kUsageSchema = arrow::struct_({
+ arrow::field("fk_catalog", arrow::utf8()),
+ arrow::field("fk_db_schema", arrow::utf8()),
+ arrow::field("fk_table", arrow::utf8(), /*nullable=*/false),
+ arrow::field("fk_column_name", arrow::utf8(), /*nullable=*/false),
+ });
+ static std::shared_ptr kConstraintSchema = arrow::struct_({
+ arrow::field("constraint_name", arrow::utf8()),
+ arrow::field("constraint_type", arrow::utf8(), /*nullable=*/false),
+ arrow::field("constraint_column_names", arrow::list(arrow::utf8()),
+ /*nullable=*/false),
+ arrow::field("constraint_column_usage", arrow::list(kUsageSchema)),
+ });
+ static std::shared_ptr kTableSchema = arrow::struct_({
+ arrow::field("table_name", arrow::utf8(), /*nullable=*/false),
+ arrow::field("table_type", arrow::utf8(), /*nullable=*/false),
+ arrow::field("table_columns", arrow::list(kColumnSchema)),
+ arrow::field("table_constraints", arrow::list(kConstraintSchema)),
+ });
+ static std::shared_ptr kDbSchemaSchema = arrow::struct_({
+ arrow::field("db_schema_name", arrow::utf8()),
+ arrow::field("db_schema_tables", arrow::list(kTableSchema)),
+ });
+ static std::shared_ptr kCatalogSchema = arrow::schema({
+ arrow::field("catalog_name", arrow::utf8()),
+ arrow::field("catalog_db_schemas", arrow::list(kDbSchemaSchema)),
+ });
+ FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery);
+
+ // To avoid an N+1 query problem, we assume result sets here will
+ // fit in memory and build up a single response.
+
+ std::string db_schema_filter = db_schema ? db_schema : "";
+ std::string table_name_filter = table_name ? table_name : "";
+ std::vector table_type_filter;
+ if (table_type) {
+ // Copy entries up to the terminating null pointer. table_type itself stays
+ // non-null afterwards, which the GetTables call below relies on.
+ while (*table_type) {
+ table_type_filter.emplace_back(*table_type);
+ table_type++;
+ }
+ }
+
+ TableIndex table_index;
+ std::shared_ptr table_data;
+
+ // Step 1: fetch schemas (and, if requested, tables) for all catalogs up
+ // front and index them in table_index.
+ if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_DB_SCHEMAS) {
+ std::unique_ptr info;
+ ADBC_ARROW_RETURN_NOT_OK(error,
+ client_
+ ->GetDbSchemas(call_options, /*catalog=*/nullptr,
+ db_schema ? &db_schema_filter : nullptr)
+ .Value(&info));
+ FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch),
+ std::move(info));
+ ADBC_RETURN_NOT_OK(reader.Init(error));
+ ADBC_ARROW_RETURN_NOT_OK(error, IndexDbSchemas(&reader, &table_index));
+ }
+ if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_TABLES) {
+ std::unique_ptr info;
+ ADBC_ARROW_RETURN_NOT_OK(error,
+ client_
+ ->GetTables(call_options, /*catalog=*/nullptr,
+ db_schema ? &db_schema_filter : nullptr,
+ table_name ? &table_name_filter : nullptr,
+ /*include_schemas=*/true,
+ table_type ? &table_type_filter : nullptr)
+ .Value(&info));
+ FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch),
+ std::move(info));
+ ADBC_RETURN_NOT_OK(reader.Init(error));
+ ADBC_ARROW_RETURN_NOT_OK(error, IndexTables(&reader, &table_index, &table_data));
+ }
+
+ // Step 2: set up the nested builders. Each pointer below addresses one level
+ // of catalog_db_schemas_builder, following the field order of the schemas
+ // declared above.
+ std::unique_ptr raw_builder;
+ ADBC_ARROW_RETURN_NOT_OK(error, MakeBuilder(kDbSchemaSchema).Value(&raw_builder));
+ StringBuilder catalog_name_builder;
+ ListBuilder catalog_db_schemas_builder(default_memory_pool(), std::move(raw_builder));
+ auto* catalog_db_schemas_items_builder =
+ checked_cast(catalog_db_schemas_builder.value_builder());
+ auto* db_schema_name_builder =
+ checked_cast(catalog_db_schemas_items_builder->field_builder(0));
+ auto* db_schema_tables_builder =
+ checked_cast(catalog_db_schemas_items_builder->field_builder(1));
+ auto* db_schema_tables_items_builder =
+ checked_cast(db_schema_tables_builder->value_builder());
+ auto* table_name_builder =
+ checked_cast(db_schema_tables_items_builder->field_builder(0));
+ auto* table_type_builder =
+ checked_cast(db_schema_tables_items_builder->field_builder(1));
+ auto* table_columns_builder =
+ checked_cast(db_schema_tables_items_builder->field_builder(2));
+ auto* table_constraints_builder =
+ checked_cast(db_schema_tables_items_builder->field_builder(3));
+ auto* table_columns_items_builder =
+ checked_cast(table_columns_builder->value_builder());
+ auto* column_name_builder =
+ checked_cast(table_columns_items_builder->field_builder(0));
+ auto* ordinal_position_builder =
+ checked_cast(table_columns_items_builder->field_builder(1));
+
+ // Step 3: walk the catalogs and emit one output row per catalog
+ std::unique_ptr info;
+ ADBC_ARROW_RETURN_NOT_OK(error, client_->GetCatalogs(call_options).Value(&info));
+ FlightInfoReader catalog_reader(client_.get(), MakeCallOptions(CallContext::kFetch),
+ std::move(info));
+ ADBC_RETURN_NOT_OK(catalog_reader.Init(error));
+
+ ReaderIterator catalogs(*SqlSchema::GetCatalogsSchema(), &catalog_reader);
+ ADBC_ARROW_RETURN_NOT_OK(error, catalogs.Init());
+ while (true) {
+ bool have_data = false;
+ ADBC_ARROW_RETURN_NOT_OK(error, catalogs.Next().Value(&have_data));
+ if (!have_data) break;
+
+ std::optional cur_catalog_name(catalogs.GetNullable(0));
+ // TODO(lidavidm): catalog is a filter string (evaluate with compute fn if
+ // available)
+ if (catalog && cur_catalog_name != catalog) continue;
+
+ if (cur_catalog_name) {
+ ADBC_ARROW_RETURN_NOT_OK(error, catalog_name_builder.Append(*cur_catalog_name));
+ } else {
+ ADBC_ARROW_RETURN_NOT_OK(error, catalog_name_builder.AppendNull());
+ }
+
+ if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_DB_SCHEMAS) {
+ ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_builder.Append());
+ } else {
+ // Catalog-only listing: the schema list is null
+ ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_builder.AppendNull());
+ continue;
+ }
+
+ // No indexed schemas for this catalog: leave its schema list empty
+ auto it = table_index.find(cur_catalog_name);
+ if (it == table_index.end()) continue;
+
+ for (const auto& schema_item : it->second) {
+ ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_items_builder->Append());
+ const std::optional& cur_schema_name = schema_item.first;
+
+ if (cur_schema_name) {
+ ADBC_ARROW_RETURN_NOT_OK(error,
+ db_schema_name_builder->Append(*cur_schema_name));
+ } else {
+ ADBC_ARROW_RETURN_NOT_OK(error, db_schema_name_builder->AppendNull());
+ }
+
+ if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_TABLES) {
+ ADBC_ARROW_RETURN_NOT_OK(error, db_schema_tables_builder->Append());
+ } else {
+ ADBC_ARROW_RETURN_NOT_OK(error, db_schema_tables_builder->AppendNull());
+ continue;
+ }
+
+ // [first, second) is the range of rows in table_data for this schema
+ std::pair schema_tables_range = schema_item.second;
+ for (int64_t i = schema_tables_range.first; i < schema_tables_range.second; i++) {
+ // NOTE: these two locals shadow the function parameters of the same name
+ const std::string_view table_name = GetNotNull(*table_data, 2, i);
+ const std::string_view table_type = GetNotNull(*table_data, 3, i);
+ // Serialized schema of the table, deserialized below with ipc::ReadSchema
+ const std::string_view table_schema = GetNotNull(*table_data, 4, i);
+
+ ADBC_ARROW_RETURN_NOT_OK(error, db_schema_tables_items_builder->Append());
+ ADBC_ARROW_RETURN_NOT_OK(error, table_name_builder->Append(table_name));
+ ADBC_ARROW_RETURN_NOT_OK(error, table_type_builder->Append(table_type));
+
+ if (depth == ADBC_OBJECT_DEPTH_COLUMNS) {
+ ADBC_ARROW_RETURN_NOT_OK(error, table_columns_builder->Append());
+
+ std::shared_ptr schema;
+ io::BufferReader reader((Buffer(table_schema)));
+ ipc::DictionaryMemo memo;
+ ADBC_ARROW_RETURN_NOT_OK(error,
+ ipc::ReadSchema(&reader, &memo).Value(&schema));
+
+ // Ordinals are 1-based and only advance for columns that pass the filter
+ int32_t ordinal_position = 1;
+ for (const std::shared_ptr& field : schema->fields()) {
+#ifdef ARROW_COMPUTE
+ if (column_name && std::strlen(column_name) > 0) {
+ Datum result;
+ arrow::compute::MatchSubstringOptions options(column_name);
+ ADBC_ARROW_RETURN_NOT_OK(
+ error, arrow::compute::CallFunction(
+ "match_like", {Datum(MakeScalar(field->name()))}, &options)
+ .Value(&result));
+ DCHECK_EQ(result.kind(), Datum::Kind::SCALAR);
+ if (!result.scalar_as().value) continue;
+ }
+#else
+ if (column_name && std::strlen(column_name) > 0) {
+ SetError(error, "Cannot filter on column name without ARROW_COMPUTE");
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+ }
+#endif
+ ADBC_ARROW_RETURN_NOT_OK(error, table_columns_items_builder->Append());
+ ADBC_ARROW_RETURN_NOT_OK(error, column_name_builder->Append(field->name()));
+ ADBC_ARROW_RETURN_NOT_OK(
+ error, ordinal_position_builder->Append(ordinal_position++));
+ // Only name and ordinal are filled in; the remaining column fields are null
+ for (int i = 2; i < table_columns_items_builder->num_fields(); i++) {
+ ADBC_ARROW_RETURN_NOT_OK(
+ error, table_columns_items_builder->field_builder(i)->AppendNull());
+ }
+ }
+
+ // TODO(lidavidm): unimplemented for now
+ ADBC_ARROW_RETURN_NOT_OK(error, table_constraints_builder->Append());
+ } else {
+ ADBC_ARROW_RETURN_NOT_OK(error, table_columns_builder->AppendNull());
+ ADBC_ARROW_RETURN_NOT_OK(error, table_constraints_builder->AppendNull());
+ }
+ }
+ }
+ }
+
+ // Step 4: assemble the single result batch and export it
+ ArrayVector arrays(2);
+ ADBC_ARROW_RETURN_NOT_OK(error, catalog_name_builder.Finish(&arrays[0]));
+ ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_builder.Finish(&arrays[1]));
+ const int64_t num_rows = arrays[0]->length();
+ auto batch = RecordBatch::Make(kCatalogSchema, num_rows, std::move(arrays));
+
+ ADBC_ARROW_RETURN_NOT_OK(
+ error, ExportRecordBatches(kCatalogSchema, {std::move(batch)}, stream));
+ return ADBC_STATUS_OK;
+ }
+
+ /// \brief Look up the Arrow schema of a table.
+ ///
+ /// Queries the server's table listing with schemas included and exports the
+ /// schema of the first row of the first non-empty batch.
+ ///
+ /// \param[in] catalog catalog filter, or null for no filter
+ /// \param[in] db_schema schema filter, or null for no filter
+ /// \param[in] table_name table filter, or null for no filter
+ /// \param[out] schema receives the exported schema
+ /// \param[out] error populated on failure
+ /// \return ADBC_STATUS_NOT_FOUND if no table matches,
+ ///   ADBC_STATUS_INTERNAL if the server response has an unexpected shape
+ AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema,
+                               const char* table_name, struct ArrowSchema* schema,
+                               struct AdbcError* error) {
+   FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery);
+   std::string catalog_str, db_schema_str, table_name_str;
+
+   if (catalog) catalog_str = catalog;
+   if (db_schema) db_schema_str = db_schema;
+   if (table_name) table_name_str = table_name;
+
+   std::unique_ptr<FlightInfo> info;
+   ADBC_ARROW_RETURN_NOT_OK(
+       error, client_
+                  ->GetTables(call_options, catalog ? &catalog_str : nullptr,
+                              db_schema ? &db_schema_str : nullptr,
+                              table_name ? &table_name_str : nullptr,
+                              /*include_schema=*/true, /*table_types=*/nullptr)
+                  .Value(&info));
+
+   FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch),
+                           std::move(info));
+   ADBC_RETURN_NOT_OK(reader.Init(error));
+
+   if (!reader.schema()->Equals(*SqlSchema::GetTablesSchemaWithIncludedSchema())) {
+     SetError(error, "Server returned wrong schema, got:\n", *reader.schema(),
+              "\nExpected:\n", *SqlSchema::GetTablesSchemaWithIncludedSchema());
+     return ADBC_STATUS_INTERNAL;
+   }
+
+   while (true) {
+     std::shared_ptr<RecordBatch> batch;
+     ADBC_ARROW_RETURN_NOT_OK(error, reader.Next().Value(&batch));
+     if (!batch) break;
+     if (batch->num_rows() == 0) continue;
+
+     // Column 4 carries the serialized schema of each table
+     const auto& schema_col = batch->column(4);
+     if (schema_col->type()->id() != Type::BINARY) {
+       SetError(error, "Server returned invalid schema; expected binary, found ",
+                *schema_col->type());
+       return ADBC_STATUS_INTERNAL;
+     }
+     const auto& binary = checked_cast<const BinaryArray&>(*schema_col);
+     if (!binary.IsValid(0)) {
+       SetError(error, "Schema was null though field is non-null");
+       return ADBC_STATUS_INTERNAL;
+     }
+
+     // Only the first matching table is used
+     ipc::DictionaryMemo memo;
+     io::BufferReader stream(Buffer::FromString(binary.GetString(0)));
+     std::shared_ptr<Schema> arrow_schema;
+     ADBC_ARROW_RETURN_NOT_OK(error,
+                              ipc::ReadSchema(&stream, &memo).Value(&arrow_schema));
+     ADBC_ARROW_RETURN_NOT_OK(error, arrow::ExportSchema(*arrow_schema, schema));
+     return ADBC_STATUS_OK;
+   }
+
+   SetError(error, "No table found meeting criteria");
+   return ADBC_STATUS_NOT_FOUND;
+ }
+
+  // Fetch the table types known to the server.
+  //
+  // Issues the Flight SQL GetTableTypes call and exports the resulting
+  // record batch stream into `stream`.
+  //
+  // Failures are reported through ADBC_ARROW_RETURN_NOT_OK, like the other
+  // metadata calls in this class, instead of being flattened to
+  // ADBC_STATUS_IO regardless of the underlying Arrow status.
+  AdbcStatusCode GetTableTypes(struct ArrowArrayStream* stream, struct AdbcError* error) {
+    FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery);
+    std::unique_ptr flight_info;
+    ADBC_ARROW_RETURN_NOT_OK(error,
+                             client_->GetTableTypes(call_options).Value(&flight_info));
+    return FlightInfoReader::Export(client_.get(), MakeCallOptions(CallContext::kFetch),
+                                    std::move(flight_info), stream, error);
+  }
+
+ //----------------------------------------------------------
+ // Partitioned Results
+ //----------------------------------------------------------
+
+ // Deserialize a partition descriptor (a serialized FlightInfo of
+ // `serialized_length` bytes) and export a stream over its data into `out`.
+ // Deserialization failures are reported through `error`.
+ AdbcStatusCode ReadPartition(const uint8_t* serialized_partition,
+ size_t serialized_length, struct ArrowArrayStream* out,
+ struct AdbcError* error) {
+ std::unique_ptr info;
+ ADBC_ARROW_RETURN_NOT_OK(
+ error, FlightInfo::Deserialize(
+ std::string_view(reinterpret_cast(serialized_partition),
+ serialized_length))
+ .Value(&info));
+ return FlightInfoReader::Export(client_.get(), MakeCallOptions(CallContext::kFetch),
+ std::move(info), out, error);
+ }
+
+ //----------------------------------------------------------
+ // Transactions
+ //----------------------------------------------------------
+
+  // Commit the current transaction, then start a fresh one so the
+  // connection stays in manual-commit mode.
+  //
+  // Returns ADBC_STATUS_INVALID_STATE when there is no valid transaction
+  // (i.e. autocommit is enabled).
+  AdbcStatusCode Commit(struct AdbcError* error) {
+    if (!transaction_.is_valid()) {
+      SetError(error, "Cannot commit when autocommit is enabled");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+
+    ADBC_RETURN_NOT_OK(CheckTransactionSupport(error));
+    FlightCallOptions options = MakeCallOptions(CallContext::kUpdate);
+    ADBC_ARROW_RETURN_NOT_OK(error, client_->Commit(options, transaction_));
+    // Replace the finished transaction handle with a newly begun one.
+    ADBC_ARROW_RETURN_NOT_OK(error,
+                             client_->BeginTransaction(options).Value(&transaction_));
+    return ADBC_STATUS_OK;
+  }
+
+  // Roll back the current transaction, then start a fresh one so the
+  // connection stays in manual-commit mode.
+  //
+  // Returns ADBC_STATUS_INVALID_STATE when there is no valid transaction
+  // (i.e. autocommit is enabled).
+  AdbcStatusCode Rollback(struct AdbcError* error) {
+    if (!transaction_.is_valid()) {
+      SetError(error, "Cannot rollback when autocommit is enabled");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+
+    ADBC_RETURN_NOT_OK(CheckTransactionSupport(error));
+    FlightCallOptions options = MakeCallOptions(CallContext::kUpdate);
+    ADBC_ARROW_RETURN_NOT_OK(error, client_->Rollback(options, transaction_));
+    // Replace the abandoned transaction handle with a newly begun one.
+    ADBC_ARROW_RETURN_NOT_OK(error,
+                             client_->BeginTransaction(options).Value(&transaction_));
+    return ADBC_STATUS_OK;
+  }
+
+ private:
+  // Check SqlInfo to determine what the server does/doesn't support.
+  //
+  // Requests FLIGHT_SQL_SERVER_TRANSACTION from the server and records in
+  // support_.transactions whether the reported value is
+  // SQL_SUPPORTED_TRANSACTION_TRANSACTION or SQL_SUPPORTED_TRANSACTION_SAVEPOINT.
+  // Rows with a null code, a non-int32 value or a null value are ignored.
+  //
+  // Returns a non-OK Status if the request, the reader initialization or
+  // reading a batch fails.
+  Status QuerySqlInfo() {
+    FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery);
+    ARROW_ASSIGN_OR_RAISE(
+        auto info, client_->GetSqlInfo(call_options,
+                                       {
+                                           SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION,
+                                       }));
+
+    FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch),
+                            std::move(info));
+    // The reader reports failures ADBC-style; translate them into a Status
+    // and release the temporary error so its message is not leaked.
+    struct AdbcError error = {};
+    if (reader.Init(&error) != ADBC_STATUS_OK) {
+      std::string message = "Could not initialize reader: ";
+      if (error.message) {
+        message += error.message;
+      }
+      if (error.release) {
+        error.release(&error);
+      }
+      return Status::IOError(std::move(message));
+    }
+
+    // Type code / child index of the int32 member of the value union.
+    constexpr int8_t kInt32Code = 3;
+    while (true) {
+      ARROW_ASSIGN_OR_RAISE(auto batch, reader.Next());
+      if (!batch) break;
+
+      const auto& sql_codes = checked_cast(*batch->column(0));
+      const auto& sql_value = checked_cast(*batch->column(1));
+      const auto& sql_int32 =
+          checked_cast(*sql_value.field(kInt32Code));
+      for (int64_t i = 0; i < batch->num_rows(); i++) {
+        if (!sql_codes.IsValid(i)) continue;
+
+        switch (sql_codes.Value(i)) {
+          case SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION: {
+            if (sql_value.type_code(i) != kInt32Code) {
+              continue;
+            }
+            const int32_t idx = sql_value.value_offset(i);
+            if (!sql_int32.IsValid(idx)) {
+              continue;
+            }
+            // Read the actual enum value. Combining it with IsValid() via &&
+            // would collapse it to 0/1 and make any non-zero value look like
+            // SQL_SUPPORTED_TRANSACTION_TRANSACTION.
+            const int32_t value = sql_int32.Value(idx);
+            support_.transactions =
+                (value == SqlInfoOptions::SQL_SUPPORTED_TRANSACTION_TRANSACTION ||
+                 value == SqlInfoOptions::SQL_SUPPORTED_TRANSACTION_SAVEPOINT);
+            break;
+          }
+          default:
+            continue;
+        }
+      }
+    }
+
+    return reader.Close();
+  }
+
+  // Gate for transaction-related calls: succeeds only when the server
+  // advertised transaction support (see support_.transactions).
+  AdbcStatusCode CheckTransactionSupport(struct AdbcError* error) {
+    if (support_.transactions) {
+      return ADBC_STATUS_OK;
+    }
+    SetError(error, "Server does not report transaction support");
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+
+ // Server capabilities; `transactions` is filled in by QuerySqlInfo() and
+ // consulted by CheckTransactionSupport().
+ struct {
+ bool transactions = false;
+ } support_;
+ std::shared_ptr database_;
+ // Flight SQL client used for all RPCs issued by this connection.
+ std::unique_ptr client_;
+ // Current transaction handle; Commit()/Rollback() treat an invalid handle
+ // as "autocommit is enabled".
+ Transaction transaction_ = no_transaction();
+ std::shared_ptr quirks_;
+ std::unordered_map> timeout_seconds_;
+ std::unordered_map call_headers_;
+};
+
+// Owns the serialized partition descriptors handed out through an
+// AdbcPartitions struct, together with the parallel pointer/length arrays
+// that the struct exposes to the caller.
+class FlightSqlPartitionsImpl {
+ public:
+  // Take ownership of the serialized partitions and build the pointer and
+  // length arrays that index into them.
+  explicit FlightSqlPartitionsImpl(std::vector partitions)
+      : partitions_(std::move(partitions)) {
+    // Both arrays are written by index below, so they must be sized (not
+    // merely reserved): indexing past size() is undefined behavior and
+    // data() of an empty vector need not point at usable storage.
+    pointers_.resize(partitions_.size());
+    lengths_.resize(partitions_.size());
+
+    for (size_t i = 0; i < partitions_.size(); i++) {
+      pointers_[i] = reinterpret_cast(partitions_[i].data());
+      lengths_[i] = partitions_[i].size();
+    }
+  }
+
+  // Split `info` into one single-endpoint FlightInfo per endpoint, serialize
+  // each of them, and populate `out` with the result.
+  //
+  // On success `out` owns a heap-allocated FlightSqlPartitionsImpl that is
+  // destroyed by out->release. On failure `out` is left untouched.
+  static AdbcStatusCode Export(const FlightInfo& info, struct AdbcPartitions* out,
+                               struct AdbcError* error) {
+    std::vector partitions;
+    partitions.reserve(info.endpoints().size());
+    for (const FlightEndpoint& endpoint : info.endpoints()) {
+      // Each partition carries the schema and descriptor of the full result
+      // but only a single endpoint; the record/byte totals are set to -1.
+      FlightInfo partition_info(FlightInfo::Data{
+          info.serialized_schema(),
+          info.descriptor(),
+          {endpoint},
+          /*total_records=*/-1,
+          /*total_bytes=*/-1,
+      });
+      std::string serialized;
+      ADBC_ARROW_RETURN_NOT_OK(error,
+                               partition_info.SerializeToString().Value(&serialized));
+      partitions.push_back(std::move(serialized));
+    }
+
+    auto* impl = new FlightSqlPartitionsImpl(std::move(partitions));
+    out->num_partitions = impl->partitions_.size();
+    out->partitions = impl->pointers_.data();
+    out->partition_lengths = impl->lengths_.data();
+    out->private_data = impl;
+    out->release = &Release;
+    return ADBC_STATUS_OK;
+  }
+
+  // Release callback installed by Export(): destroys the owning object and
+  // resets every field of `partitions`.
+  static void Release(struct AdbcPartitions* partitions) {
+    auto* impl = static_cast(partitions->private_data);
+    delete impl;
+    partitions->num_partitions = 0;
+    partitions->partitions = nullptr;
+    partitions->partition_lengths = nullptr;
+    partitions->private_data = nullptr;
+    partitions->release = nullptr;
+  }
+
+ private:
+  // Serialized partition descriptors; pointers_ points into these.
+  std::vector partitions_;
+  // pointers_[i] / lengths_[i] describe partitions_[i].
+  std::vector pointers_;
+  std::vector lengths_;
+};
+
+// Implementation behind an AdbcStatement: holds the query/plan, bound
+// parameters and bulk-ingest settings for one statement on a connection.
+class FlightSqlStatementImpl {
+ public:
+ // Create a statement that issues its calls through `connection`.
+ explicit FlightSqlStatementImpl(std::shared_ptr connection)
+ : connection_(std::move(connection)) {}
+
+ // Bind a single batch of parameters: imports `array`/`schema` as a record
+ // batch and wraps it in a one-batch reader stored in bind_parameters_.
+ AdbcStatusCode Bind(struct ArrowArray* array, struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ std::shared_ptr batch;
+ ADBC_ARROW_RETURN_NOT_OK(error, ImportRecordBatch(array, schema).Value(&batch));
+ ADBC_ARROW_RETURN_NOT_OK(
+ error, RecordBatchReader::Make({std::move(batch)}).Value(&bind_parameters_));
+ return ADBC_STATUS_OK;
+ }
+
+ // Bind a stream of parameters: imports `stream` as a record batch reader
+ // and stores it in bind_parameters_.
+ AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error) {
+ ADBC_ARROW_RETURN_NOT_OK(error,
+ ImportRecordBatchReader(stream).Value(&bind_parameters_));
+ return ADBC_STATUS_OK;
+ }
+
+  // Shut the statement down by closing its prepared statement, if any.
+  AdbcStatusCode Close(struct AdbcError* error) {
+    const AdbcStatusCode close_status = ClosePreparedStatement(error);
+    if (close_status != ADBC_STATUS_OK) return close_status;
+    return ADBC_STATUS_OK;
+  }
+
+  // Execute the statement and return the result as a set of partitions
+  // (one per Flight endpoint) instead of a single stream.
+  //
+  // `schema` receives the result schema and `partitions` the serialized
+  // partition descriptors. If `rows_affected` is non-null it receives the
+  // total record count reported in the FlightInfo.
+  AdbcStatusCode ExecutePartitions(struct ArrowSchema* schema,
+                                   struct AdbcPartitions* partitions,
+                                   int64_t* rows_affected, struct AdbcError* error) {
+    std::unique_ptr info;
+    ADBC_ARROW_RETURN_NOT_OK(error, ExecuteFlightInfo().Value(&info));
+    if (rows_affected) *rows_affected = info->total_records();
+
+    ipc::DictionaryMemo memo;
+    std::shared_ptr arrow_schema;
+    ADBC_ARROW_RETURN_NOT_OK(error, info->GetSchema(&memo).Value(&arrow_schema));
+    ADBC_ARROW_RETURN_NOT_OK(error, ExportSchema(*arrow_schema, schema));
+
+    const AdbcStatusCode partition_status =
+        FlightSqlPartitionsImpl::Export(*info, partitions, error);
+    if (partition_status != ADBC_STATUS_OK) {
+      // The schema was already exported; release it so a caller that only
+      // looks at the failure status does not leak it.
+      if (schema->release) schema->release(schema);
+      return partition_status;
+    }
+    return ADBC_STATUS_OK;
+  }
+
+ // Execute the statement.
+ //
+ // Three modes, chosen in this order:
+ //  - a bulk-ingest target table is set: run ExecuteIngest() (`out` must be
+ //    null);
+ //  - `out` is null: run the statement as an update and report the row count;
+ //  - otherwise: run it as a query and export the result stream into `out`.
+ //
+ // Returns ADBC_STATUS_INVALID_STATE if no plan, query or prepared statement
+ // has been provided.
+ AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* out, int64_t* rows_affected,
+ struct AdbcError* error) {
+ if (!ingest_.target_table.empty()) {
+ if (out) {
+ SetError(error, "Must not provide out for bulk ingest");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ return ExecuteIngest(rows_affected, error);
+ } else if (plan_.plan.empty() && query_.empty() && !prepared_statement_) {
+ SetError(error, "Must provide query");
+ return ADBC_STATUS_INVALID_STATE;
+ }
+
+ // No result stream requested: treat the statement as an update.
+ if (!out) {
+ int64_t out_rows = 0;
+ ADBC_ARROW_RETURN_NOT_OK(error, ExecuteUpdate().Value(&out_rows));
+ if (rows_affected) *rows_affected = out_rows;
+ return ADBC_STATUS_OK;
+ }
+
+ std::unique_ptr info;
+ ADBC_ARROW_RETURN_NOT_OK(error, ExecuteFlightInfo().Value(&info));
+ // For queries the reported count is the FlightInfo's total record count.
+ if (rows_affected) *rows_affected = info->total_records();
+ return FlightInfoReader::Export(connection_->client(),
+ MakeCallOptions(CallContext::kFetch), std::move(info),
+ out, error);
+ }
+
+  // Export the parameter schema of the prepared statement into `schema`.
+  // Requires a prior successful Prepare(); otherwise fails with
+  // ADBC_STATUS_INVALID_STATE.
+  AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error) {
+    if (prepared_statement_) {
+      const auto& parameter_schema = *prepared_statement_->parameter_schema();
+      ADBC_ARROW_RETURN_NOT_OK(error, arrow::ExportSchema(parameter_schema, schema));
+      return ADBC_STATUS_OK;
+    }
+    SetError(error, "Must Prepare() before GetParameterSchema()");
+    return ADBC_STATUS_INVALID_STATE;
+  }
+
+  // Create a server-side prepared statement from the Substrait plan or the
+  // SQL query, replacing (and closing) any previously prepared statement.
+  //
+  // Fails with ADBC_STATUS_INVALID_STATE when neither a plan nor a query has
+  // been set.
+  AdbcStatusCode Prepare(struct AdbcError* error) {
+    const bool has_plan = !plan_.plan.empty();
+    if (!has_plan && query_.empty()) {
+      SetError(error, "Must provide query");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+
+    FlightCallOptions options = MakeCallOptions(CallContext::kUpdate);
+    ADBC_RETURN_NOT_OK(ClosePreparedStatement(error));
+
+    if (has_plan) {
+      // Setting a plan clears the query (and vice versa), so both cannot be
+      // present at once.
+      DCHECK(query_.empty());
+      ADBC_ARROW_RETURN_NOT_OK(
+          error, connection_->client()
+                     ->PrepareSubstrait(options, plan_, connection_->transaction())
+                     .Value(&prepared_statement_));
+      return ADBC_STATUS_OK;
+    }
+
+    ADBC_ARROW_RETURN_NOT_OK(error,
+                             connection_->client()
+                                 ->Prepare(options, query_, connection_->transaction())
+                                 .Value(&prepared_statement_));
+    return ADBC_STATUS_OK;
+  }
+
+  // Set a statement-level option.
+  //
+  // Recognized keys:
+  //  - kIngestOptionTargetTable: switch to bulk ingest into the given table
+  //    (clears any query, plan, prepared statement and bound parameters);
+  //  - kIngestOptionMode: kIngestOptionModeCreate or kIngestOptionModeAppend;
+  //  - kStatementSubstraitVersionKey: version string sent with Substrait plans;
+  //  - keys starting with kCallHeaderPrefix: add, replace or (with a null
+  //    value) remove a header sent with every call made by this statement.
+  //
+  // A null `value` is treated as the empty string, except for call headers
+  // where it removes the header.
+  //
+  // Returns ADBC_STATUS_INVALID_ARGUMENT for a null key or an invalid value,
+  // and ADBC_STATUS_NOT_IMPLEMENTED for an unknown key.
+  AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) {
+    if (key == nullptr) {
+      SetError(error, "Key must not be null");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    std::string_view key_view(key);
+    std::string_view val_view = value ? value : "";
+    if (key_view == kIngestOptionTargetTable) {
+      ADBC_RETURN_NOT_OK(ClearQueryParams(error));
+      ingest_.target_table = val_view;
+    } else if (key_view == kIngestOptionMode) {
+      // Compare through val_view: the raw pointer may be null and must never
+      // be compared by address.
+      if (val_view == kIngestOptionModeCreate) {
+        ingest_.append = false;
+      } else if (val_view == kIngestOptionModeAppend) {
+        ingest_.append = true;
+      } else {
+        SetError(error, "Invalid statement option value ", key_view, "=", val_view);
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+    } else if (key_view == kStatementSubstraitVersionKey) {
+      // Assigning a null const char* to a std::string is undefined behavior.
+      plan_.version = std::string(val_view);
+    } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) {
+      std::string header(key_view.substr(kCallHeaderPrefix.size()));
+      if (value == nullptr) {
+        call_headers_.erase(header);
+      } else {
+        // insert() would silently keep a previously set value for this header.
+        call_headers_.insert_or_assign(std::move(header), std::string(val_view));
+      }
+      return ADBC_STATUS_OK;
+    } else {
+      SetError(error, "Unknown statement option ", key_view, "=", val_view);
+      return ADBC_STATUS_NOT_IMPLEMENTED;
+    }
+    return ADBC_STATUS_OK;
+  }
+
+  // Set the SQL query to execute, discarding any previously configured
+  // query, Substrait plan, ingest target, prepared statement and bound
+  // parameters.
+  //
+  // Returns ADBC_STATUS_INVALID_ARGUMENT if `query` is null; in that case
+  // the existing statement state is left untouched.
+  AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error) {
+    // Assigning a null const char* to std::string is undefined behavior.
+    if (query == nullptr) {
+      SetError(error, "Query must not be null");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    ADBC_RETURN_NOT_OK(ClearQueryParams(error));
+    query_ = query;
+    return ADBC_STATUS_OK;
+  }
+
+  // Set the serialized Substrait plan (`length` bytes starting at `plan`) to
+  // execute, after discarding any previously configured query state.
+  AdbcStatusCode SetSubstraitPlan(const uint8_t* plan, size_t length,
+                                  struct AdbcError* error) {
+    ADBC_RETURN_NOT_OK(ClearQueryParams(error));
+    // Copy the caller's bytes; the statement keeps its own copy of the plan.
+    plan_.plan.assign(reinterpret_cast(plan), length);
+    return ADBC_STATUS_OK;
+  }
+
+ private:
+  // Close and drop the prepared statement, if one exists. A no-op otherwise.
+  AdbcStatusCode ClosePreparedStatement(struct AdbcError* error) {
+    if (!prepared_statement_) {
+      return ADBC_STATUS_OK;
+    }
+    FlightCallOptions options = MakeCallOptions(CallContext::kUpdate);
+    ADBC_ARROW_RETURN_NOT_OK(error, prepared_statement_->Close(options));
+    // Only forget the statement once it was closed successfully.
+    prepared_statement_ = nullptr;
+    return ADBC_STATUS_OK;
+  }
+
+ // Reset everything that describes "what to execute": closes the prepared
+ // statement, clears the ingest target, plan and query, and closes and drops
+ // any bound parameters. Called before a new query/plan/target is installed.
+ AdbcStatusCode ClearQueryParams(struct AdbcError* error) {
+ // If closing the prepared statement fails, nothing else is cleared.
+ ADBC_RETURN_NOT_OK(ClosePreparedStatement(error));
+ ingest_.target_table.clear();
+ plan_.plan.clear();
+ query_.clear();
+ if (bind_parameters_) {
+ ADBC_ARROW_RETURN_NOT_OK(error, bind_parameters_->Close());
+ bind_parameters_.reset();
+ }
+ return ADBC_STATUS_OK;
+ }
+
+  // Bulk-ingest the bound data into ingest_.target_table.
+  //
+  // Unless append mode is enabled, the table is first created with a
+  // generated CREATE TABLE statement whose column types are looked up in the
+  // connection quirks' ingest_type_mapping. The data is then inserted through
+  // a prepared "INSERT INTO <table> VALUES (?, ...)" statement.
+  //
+  // If `rows_affected` is non-null it receives the row count returned by the
+  // insert.
+  //
+  // Returns ADBC_STATUS_INVALID_STATE if nothing was bound,
+  // ADBC_STATUS_INVALID_ARGUMENT for table/column names that fail
+  // IsApproximatelyValidIdentifier (names are spliced into SQL text), and
+  // ADBC_STATUS_NOT_IMPLEMENTED for column types without a mapping.
+  //
+  // The error struct is only populated on failure; the generated SQL is not
+  // reported through it.
+  AdbcStatusCode ExecuteIngest(int64_t* rows_affected, struct AdbcError* error) {
+    if (!bind_parameters_) {
+      SetError(error, "Must Bind() before bulk ingestion");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+    if (!IsApproximatelyValidIdentifier(ingest_.target_table)) {
+      SetError(error, "Invalid target table ", ingest_.target_table,
+               ": must be alphanumeric with underscores");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate);
+    if (!ingest_.append) {
+      std::string create = "CREATE TABLE ";
+
+      create += ingest_.target_table;
+      create += " (";
+
+      bool first = true;
+      for (const auto& field : bind_parameters_->schema()->fields()) {
+        if (!first) {
+          create += ", ";
+        }
+        first = false;
+
+        if (!IsApproximatelyValidIdentifier(field->name())) {
+          SetError(error, "Invalid column name ", field->name(),
+                   ": must be alphanumeric with underscores");
+          return ADBC_STATUS_INVALID_ARGUMENT;
+        }
+
+        create += field->name();
+        const auto& mapping = connection_->quirks().ingest_type_mapping;
+        const auto it = mapping.find(field->type()->id());
+        if (it == mapping.end()) {
+          SetError(error, "Data type not supported for bulk ingest: ", field->ToString());
+          return ADBC_STATUS_NOT_IMPLEMENTED;
+        }
+        create += " ";
+        create += it->second;
+      }
+      create += ")";
+
+      ADBC_ARROW_RETURN_NOT_OK(
+          error, connection_->client()
+                     ->ExecuteUpdate(call_options, create, connection_->transaction())
+                     .status());
+    }
+
+    // One positional placeholder per bound column.
+    std::string append = "INSERT INTO ";
+    append += ingest_.target_table;
+    append += " VALUES (";
+    for (int i = 0; i < bind_parameters_->schema()->num_fields(); i++) {
+      if (i > 0) append += ", ";
+      // TODO(lidavidm): add config option for parameter symbol (or
+      // query Flight SQL metadata?)
+      append += "?";
+    }
+    append += ")";
+
+    std::shared_ptr stmt;
+    ADBC_ARROW_RETURN_NOT_OK(
+        error, connection_->client()
+                   ->Prepare(call_options, append, connection_->transaction())
+                   .Value(&stmt));
+
+    int64_t total_rows = 0;
+    // The bound reader is consumed by the insert; bind_parameters_ is left
+    // empty afterwards.
+    ADBC_ARROW_RETURN_NOT_OK(error, stmt->SetParameters(std::move(bind_parameters_)));
+    // Pass the call options so the statement's headers/timeouts also apply
+    // to the insert and the close, as for every other call made here.
+    ADBC_ARROW_RETURN_NOT_OK(error, stmt->ExecuteUpdate(call_options).Value(&total_rows));
+    ADBC_ARROW_RETURN_NOT_OK(error, stmt->Close(call_options));
+    if (rows_affected) *rows_affected = total_rows;
+    return ADBC_STATUS_OK;
+  }
+
+ // Run the statement as an update and return the result of the underlying
+ // client call. Precedence: prepared statement, then Substrait plan, then
+ // SQL query.
+ arrow::Result ExecuteUpdate() {
+ FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate);
+ if (prepared_statement_) {
+ // Hand any pending bound parameters to the prepared statement; this
+ // leaves bind_parameters_ empty.
+ if (bind_parameters_) {
+ ARROW_RETURN_NOT_OK(
+ prepared_statement_->SetParameters(std::move(bind_parameters_)));
+ }
+ return prepared_statement_->ExecuteUpdate(call_options);
+ }
+
+ if (!plan_.plan.empty()) {
+ return connection_->client()->ExecuteSubstraitUpdate(call_options, plan_,
+ connection_->transaction());
+ }
+ return connection_->client()->ExecuteUpdate(call_options, query_,
+ connection_->transaction());
+ }
+
+ // Run the statement as a query and return the result of the underlying
+ // client call. Precedence: prepared statement, then Substrait plan, then
+ // SQL query.
+ arrow::Result> ExecuteFlightInfo() {
+ FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery);
+ if (prepared_statement_) {
+ // NOTE(review): unlike ExecuteUpdate(), SetParameters() is called here
+ // even when bind_parameters_ is null — confirm this is intended.
+ ARROW_RETURN_NOT_OK(
+ prepared_statement_->SetParameters(std::move(bind_parameters_)));
+ return prepared_statement_->Execute(call_options);
+ }
+ if (!plan_.plan.empty()) {
+ return connection_->client()->ExecuteSubstrait(call_options, plan_,
+ connection_->transaction());
+ }
+ return connection_->client()->Execute(call_options, query_,
+ connection_->transaction());
+ }
+
+  // Build the call options for `context`: start from the connection's
+  // options and append this statement's own call headers.
+  FlightCallOptions MakeCallOptions(CallContext context) const {
+    FlightCallOptions result = connection_->MakeCallOptions(context);
+    for (const auto& [name, value] : call_headers_) {
+      result.headers.emplace_back(name, value);
+    }
+    return result;
+  }
+
+ // Connection this statement executes on.
+ std::shared_ptr connection_;
+ // Set by Prepare(); reset by ClosePreparedStatement().
+ std::shared_ptr prepared_statement_;
+ // Substrait plan (SetSubstraitPlan) and its version (SetOption).
+ SubstraitPlan plan_;
+ // SQL query text set by SetSqlQuery().
+ std::string query_;
+ // Parameters set by Bind(); moved out when they are handed to a prepared
+ // statement.
+ std::shared_ptr bind_parameters_;
+ // Extra headers appended to every call by MakeCallOptions().
+ std::unordered_map call_headers_;
+ // Bulk ingest state
+ struct {
+ std::string target_table;
+ bool append = false;
+ } ingest_;
+};
+
+// Allocate the database implementation and store a heap-allocated
+// shared_ptr to it in database->private_data.
+AdbcStatusCode FlightSqlDatabaseNew(struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ auto impl = std::make_shared();
+ database->private_data = new std::shared_ptr(impl);
+ return ADBC_STATUS_OK;
+}
+
+// Forward to the database implementation's SetOption(). Fails with
+// ADBC_STATUS_INVALID_STATE if the database or its private_data is null.
+AdbcStatusCode FlightSqlDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error) {
+ if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(database->private_data);
+ return (*ptr)->SetOption(key, value, error);
+}
+
+// Forward to the database implementation's Init(). Fails with
+// ADBC_STATUS_INVALID_STATE if private_data is null.
+AdbcStatusCode FlightSqlDatabaseInit(struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ if (!database->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(database->private_data);
+ return (*ptr)->Init(error);
+}
+
+// Release the database implementation, then delete the heap-allocated
+// shared_ptr and clear private_data regardless of the Release() status.
+AdbcStatusCode FlightSqlDatabaseRelease(struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ if (!database->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(database->private_data);
+ AdbcStatusCode status = (*ptr)->Release(error);
+ delete ptr;
+ database->private_data = nullptr;
+ return status;
+}
+
+// The functions below unwrap the shared_ptr stored in
+// connection->private_data and forward to the connection implementation.
+// Except for FlightSqlConnectionNew, each fails with
+// ADBC_STATUS_INVALID_STATE when private_data is null.
+
+// Forward to the connection implementation's Commit().
+AdbcStatusCode FlightSqlConnectionCommit(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->Commit(error);
+}
+
+// Forward to the connection implementation's GetInfo().
+AdbcStatusCode FlightSqlConnectionGetInfo(struct AdbcConnection* connection,
+ uint32_t* info_codes, size_t info_codes_length,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->GetInfo(info_codes, info_codes_length, stream, error);
+}
+
+// Forward to the connection implementation's GetObjects().
+AdbcStatusCode FlightSqlConnectionGetObjects(
+ struct AdbcConnection* connection, int depth, const char* catalog,
+ const char* db_schema, const char* table_name, const char** table_type,
+ const char* column_name, struct ArrowArrayStream* stream, struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->GetObjects(depth, catalog, db_schema, table_name, table_type,
+ column_name, stream, error);
+}
+
+// Forward to the connection implementation's GetTableSchema().
+AdbcStatusCode FlightSqlConnectionGetTableSchema(
+ struct AdbcConnection* connection, const char* catalog, const char* db_schema,
+ const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->GetTableSchema(catalog, db_schema, table_name, schema, error);
+}
+
+// Forward to the connection implementation's GetTableTypes().
+AdbcStatusCode FlightSqlConnectionGetTableTypes(struct AdbcConnection* connection,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->GetTableTypes(stream, error);
+}
+
+// Forward to the connection implementation's Init() with the given database.
+AdbcStatusCode FlightSqlConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->Init(database, error);
+}
+
+// Allocate the connection implementation and store a heap-allocated
+// shared_ptr to it in connection->private_data.
+AdbcStatusCode FlightSqlConnectionNew(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ auto impl = std::make_shared();
+ connection->private_data = new std::shared_ptr(impl);
+ return ADBC_STATUS_OK;
+}
+
+// Forward to the connection implementation's ReadPartition().
+AdbcStatusCode FlightSqlConnectionReadPartition(struct AdbcConnection* connection,
+ const uint8_t* serialized_partition,
+ size_t serialized_length,
+ struct ArrowArrayStream* out,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->ReadPartition(serialized_partition, serialized_length, out, error);
+}
+
+// Close the connection implementation, then delete the heap-allocated
+// shared_ptr and clear private_data regardless of the Close() status.
+AdbcStatusCode FlightSqlConnectionRelease(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ auto status = (*ptr)->Close(error);
+ delete ptr;
+ connection->private_data = nullptr;
+ return status;
+}
+
+// Forward to the connection implementation's Rollback().
+AdbcStatusCode FlightSqlConnectionRollback(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->Rollback(error);
+}
+
+// Forward to the connection implementation's SetOption().
+AdbcStatusCode FlightSqlConnectionSetOption(struct AdbcConnection* connection,
+ const char* key, const char* value,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr = reinterpret_cast*>(
+ connection->private_data);
+ return (*ptr)->SetOption(key, value, error);
+}
+
+// The functions below unwrap the shared_ptr stored in
+// statement->private_data and forward to the statement implementation; each
+// fails with ADBC_STATUS_INVALID_STATE when private_data is null.
+
+// Forward to the statement implementation's Bind() (array + schema overload).
+AdbcStatusCode FlightSqlStatementBind(struct AdbcStatement* statement,
+ struct ArrowArray* values,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->Bind(values, schema, error);
+}
+
+// Forward to the statement implementation's Bind() (stream overload).
+AdbcStatusCode FlightSqlStatementBindStream(struct AdbcStatement* statement,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->Bind(stream, error);
+}
+
+// Forward to the statement implementation's ExecutePartitions().
+AdbcStatusCode FlightSqlStatementExecutePartitions(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcPartitions* partitions,
+ int64_t* rows_affected,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->ExecutePartitions(schema, partitions, rows_affected, error);
+}
+
+// Forward to the statement implementation's ExecuteQuery().
+AdbcStatusCode FlightSqlStatementExecuteQuery(struct AdbcStatement* statement,
+ struct ArrowArrayStream* out,
+ int64_t* rows_affected,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->ExecuteQuery(out, rows_affected, error);
+}
+
+// Forward to the statement implementation's GetParameterSchema().
+AdbcStatusCode FlightSqlStatementGetParameterSchema(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->GetParameterSchema(schema, error);
+}
+
+// Create a statement on `connection` and store a heap-allocated shared_ptr
+// to its implementation in statement->private_data.
+//
+// Fails with ADBC_STATUS_INVALID_STATE when the connection carries no
+// implementation (private_data is null), matching the other entry points
+// instead of dereferencing a null pointer.
+AdbcStatusCode FlightSqlStatementNew(struct AdbcConnection* connection,
+                                     struct AdbcStatement* statement,
+                                     struct AdbcError* error) {
+  if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+  auto* ptr = reinterpret_cast*>(
+      connection->private_data);
+  auto impl = std::make_shared(*ptr);
+  statement->private_data = new std::shared_ptr(impl);
+  return ADBC_STATUS_OK;
+}
+
+// The functions below unwrap the shared_ptr stored in
+// statement->private_data and forward to the statement implementation; each
+// fails with ADBC_STATUS_INVALID_STATE when private_data is null.
+
+// Forward to the statement implementation's Prepare().
+AdbcStatusCode FlightSqlStatementPrepare(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->Prepare(error);
+}
+
+// Close the statement implementation, then delete the heap-allocated
+// shared_ptr and clear private_data regardless of the Close() status.
+AdbcStatusCode FlightSqlStatementRelease(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ auto status = (*ptr)->Close(error);
+ delete ptr;
+ statement->private_data = nullptr;
+ return status;
+}
+
+// Forward to the statement implementation's SetOption().
+AdbcStatusCode FlightSqlStatementSetOption(struct AdbcStatement* statement,
+ const char* key, const char* value,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->SetOption(key, value, error);
+}
+
+// Forward to the statement implementation's SetSqlQuery().
+AdbcStatusCode FlightSqlStatementSetSqlQuery(struct AdbcStatement* statement,
+ const char* query, struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->SetSqlQuery(query, error);
+}
+
+// Forward to the statement implementation's SetSubstraitPlan().
+AdbcStatusCode FlightSqlStatementSetSubstraitPlan(struct AdbcStatement* statement,
+ const uint8_t* plan, size_t length,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->SetSubstraitPlan(plan, length, error);
+}
+} // namespace
+} // namespace arrow::flight::sql
+
+// Global-namespace ADBC database entry points. Each one forwards its
+// arguments unchanged to the FlightSql* function of the same name in
+// arrow::flight::sql.
+
+// Forwards to FlightSqlDatabaseInit.
+AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlDatabaseInit(database, error);
+}
+
+// Forwards to FlightSqlDatabaseNew.
+AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlDatabaseNew(database, error);
+}
+
+// Forwards to FlightSqlDatabaseSetOption.
+AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlDatabaseSetOption(database, key, value, error);
+}
+
+// Forwards to FlightSqlDatabaseRelease.
+AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlDatabaseRelease(database, error);
+}
+
+// Global-namespace ADBC connection entry points. Each one forwards its
+// arguments unchanged to the FlightSql* function of the same name in
+// arrow::flight::sql.
+
+// Forwards to FlightSqlConnectionCommit.
+AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionCommit(connection, error);
+}
+
+// Forwards to FlightSqlConnectionGetInfo.
+AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
+ uint32_t* info_codes, size_t info_codes_length,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionGetInfo(connection, info_codes,
+ info_codes_length, stream, error);
+}
+
+// Forwards to FlightSqlConnectionGetObjects.
+AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth,
+ const char* catalog, const char* db_schema,
+ const char* table_name, const char** table_type,
+ const char* column_name,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionGetObjects(
+ connection, depth, catalog, db_schema, table_name, table_type, column_name, stream,
+ error);
+}
+
+// Forwards to FlightSqlConnectionGetTableSchema.
+AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
+ const char* catalog, const char* db_schema,
+ const char* table_name,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionGetTableSchema(
+ connection, catalog, db_schema, table_name, schema, error);
+}
+
+// Forwards to FlightSqlConnectionGetTableTypes.
+AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionGetTableTypes(connection, stream, error);
+}
+
+// Forwards to FlightSqlConnectionInit.
+AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionInit(connection, database, error);
+}
+
+// Forwards to FlightSqlConnectionNew.
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionNew(connection, error);
+}
+
+// Forwards to FlightSqlConnectionReadPartition.
+AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
+ const uint8_t* serialized_partition,
+ size_t serialized_length,
+ struct ArrowArrayStream* out,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionReadPartition(
+ connection, serialized_partition, serialized_length, out, error);
+}
+
+// Forwards to FlightSqlConnectionRelease.
+AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionRelease(connection, error);
+}
+
+// Forwards to FlightSqlConnectionRollback.
+AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionRollback(connection, error);
+}
+
+// Forwards to FlightSqlConnectionSetOption.
+AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
+ const char* value, struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlConnectionSetOption(connection, key, value, error);
+}
+
+// Global-namespace ADBC statement entry points. Each one forwards its
+// arguments unchanged to the FlightSql* function of the same name in
+// arrow::flight::sql.
+
+// Forwards to FlightSqlStatementBind.
+AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
+ struct ArrowArray* values, struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementBind(statement, values, schema, error);
+}
+
+// Forwards to FlightSqlStatementBindStream.
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementBindStream(statement, stream, error);
+}
+
+// Forwards to FlightSqlStatementExecutePartitions.
+// XXX: cpplint gets confused if declared as struct ArrowSchema*
+AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
+ ArrowSchema* schema,
+ struct AdbcPartitions* partitions,
+ int64_t* rows_affected,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementExecutePartitions(
+ statement, schema, partitions, rows_affected, error);
+}
+
+// Forwards to FlightSqlStatementExecuteQuery.
+AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
+ struct ArrowArrayStream* out,
+ int64_t* rows_affected,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementExecuteQuery(statement, out, rows_affected,
+ error);
+}
+
+// Forwards to FlightSqlStatementGetParameterSchema.
+AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementGetParameterSchema(statement, schema,
+ error);
+}
+
+// Forwards to FlightSqlStatementNew.
+AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementNew(connection, statement, error);
+}
+
+// Forwards to FlightSqlStatementPrepare.
+AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementPrepare(statement, error);
+}
+
+// Forwards to FlightSqlStatementRelease.
+AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementRelease(statement, error);
+}
+
+// Forwards to FlightSqlStatementSetOption.
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
+ const char* value, struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementSetOption(statement, key, value, error);
+}
+
+// Forwards to FlightSqlStatementSetSqlQuery.
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
+ const char* query, struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementSetSqlQuery(statement, query, error);
+}
+
+// Forwards to FlightSqlStatementSetSubstraitPlan.
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement,
+ const uint8_t* plan, size_t length,
+ struct AdbcError* error) {
+ return arrow::flight::sql::FlightSqlStatementSetSubstraitPlan(statement, plan, length,
+ error);
+}
+
+extern "C" {
+// Driver entry point: fill the driver function table pointed to by
+// `raw_driver` with the Flight SQL implementations.
+//
+// Returns ADBC_STATUS_NOT_IMPLEMENTED for any version other than
+// ADBC_VERSION_1_0_0 and ADBC_STATUS_INVALID_ARGUMENT when `raw_driver` is
+// null. The table is zeroed first, so entries not assigned below stay null.
+ARROW_FLIGHT_SQL_EXPORT
+AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) {
+  if (version != ADBC_VERSION_1_0_0) return ADBC_STATUS_NOT_IMPLEMENTED;
+  // Without a table to fill there is nothing to initialize; do not hand a
+  // null pointer to memset.
+  if (raw_driver == nullptr) return ADBC_STATUS_INVALID_ARGUMENT;
+
+  auto* driver = reinterpret_cast(raw_driver);
+  std::memset(driver, 0, sizeof(*driver));
+
+  driver->DatabaseInit = arrow::flight::sql::FlightSqlDatabaseInit;
+  driver->DatabaseNew = arrow::flight::sql::FlightSqlDatabaseNew;
+  driver->DatabaseRelease = arrow::flight::sql::FlightSqlDatabaseRelease;
+  driver->DatabaseSetOption = arrow::flight::sql::FlightSqlDatabaseSetOption;
+
+  driver->ConnectionCommit = arrow::flight::sql::FlightSqlConnectionCommit;
+  driver->ConnectionGetInfo = arrow::flight::sql::FlightSqlConnectionGetInfo;
+  driver->ConnectionGetObjects = arrow::flight::sql::FlightSqlConnectionGetObjects;
+  driver->ConnectionGetTableSchema =
+      arrow::flight::sql::FlightSqlConnectionGetTableSchema;
+  driver->ConnectionGetTableTypes = arrow::flight::sql::FlightSqlConnectionGetTableTypes;
+  driver->ConnectionInit = arrow::flight::sql::FlightSqlConnectionInit;
+  driver->ConnectionNew = arrow::flight::sql::FlightSqlConnectionNew;
+  driver->ConnectionReadPartition = arrow::flight::sql::FlightSqlConnectionReadPartition;
+  driver->ConnectionRelease = arrow::flight::sql::FlightSqlConnectionRelease;
+  driver->ConnectionRollback = arrow::flight::sql::FlightSqlConnectionRollback;
+  driver->ConnectionSetOption = arrow::flight::sql::FlightSqlConnectionSetOption;
+
+  driver->StatementBind = arrow::flight::sql::FlightSqlStatementBind;
+  driver->StatementBindStream = arrow::flight::sql::FlightSqlStatementBindStream;
+  driver->StatementExecutePartitions =
+      arrow::flight::sql::FlightSqlStatementExecutePartitions;
+  driver->StatementExecuteQuery = arrow::flight::sql::FlightSqlStatementExecuteQuery;
+  driver->StatementGetParameterSchema =
+      arrow::flight::sql::FlightSqlStatementGetParameterSchema;
+  driver->StatementNew = arrow::flight::sql::FlightSqlStatementNew;
+  driver->StatementPrepare = arrow::flight::sql::FlightSqlStatementPrepare;
+  driver->StatementRelease = arrow::flight::sql::FlightSqlStatementRelease;
+  driver->StatementSetOption = arrow::flight::sql::FlightSqlStatementSetOption;
+  driver->StatementSetSqlQuery = arrow::flight::sql::FlightSqlStatementSetSqlQuery;
+  driver->StatementSetSubstraitPlan =
+      arrow::flight::sql::FlightSqlStatementSetSubstraitPlan;
+
+  return ADBC_STATUS_OK;
+}
+}
+
+namespace arrow::flight::sql {
+
+// Namespaced alias of the global AdbcDriverInit entry point; forwards all
+// arguments to ::AdbcDriverInit unchanged.
+AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) {
+ return ::AdbcDriverInit(version, raw_driver, error);
+}
+
+} // namespace arrow::flight::sql
diff --git a/cpp/src/arrow/flight/sql/adbc_driver_internal.cc b/cpp/src/arrow/flight/sql/adbc_driver_internal.cc
new file mode 100644
index 00000000000..1b0c893531b
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/adbc_driver_internal.cc
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/adbc_driver_internal.h"
+
+#include "arrow/c/bridge.h"
+
+namespace arrow::flight::sql {
+
+AdbcStatusCode ArrowToAdbcStatusCode(const Status& status) {
+  // Translate an Arrow Status into the closest AdbcStatusCode. A Flight status
+  // detail, when one is attached, is consulted before the generic status code.
+  if (const auto flight_detail = FlightStatusDetail::UnwrapStatus(status)) {
+    switch (flight_detail->code()) {
+      case FlightStatusCode::TimedOut:
+        return ADBC_STATUS_TIMEOUT;
+      case FlightStatusCode::Cancelled:
+        return ADBC_STATUS_CANCELLED;
+      case FlightStatusCode::Unauthenticated:
+        return ADBC_STATUS_UNAUTHENTICATED;
+      case FlightStatusCode::Unauthorized:
+        return ADBC_STATUS_UNAUTHORIZED;
+      case FlightStatusCode::Unavailable:
+        return ADBC_STATUS_IO;
+      case FlightStatusCode::Internal:
+      case FlightStatusCode::Failed:
+        return ADBC_STATUS_INTERNAL;
+      default:
+        // Unlisted detail codes are resolved by the generic mapping below.
+        break;
+    }
+  }
+
+  // Generic mapping, grouped by the ADBC code that is produced.
+  switch (status.code()) {
+    case StatusCode::OK:
+      return ADBC_STATUS_OK;
+    case StatusCode::KeyError:
+      return ADBC_STATUS_NOT_FOUND;
+    case StatusCode::TypeError:
+    case StatusCode::Invalid:
+    case StatusCode::IndexError:
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    case StatusCode::IOError:
+      return ADBC_STATUS_IO;
+    case StatusCode::Cancelled:
+      return ADBC_STATUS_CANCELLED;
+    case StatusCode::NotImplemented:
+      return ADBC_STATUS_NOT_IMPLEMENTED;
+    case StatusCode::AlreadyExists:
+      return ADBC_STATUS_INVALID_STATE;
+    case StatusCode::OutOfMemory:
+    case StatusCode::CapacityError:
+    case StatusCode::SerializationError:
+    case StatusCode::RError:
+    case StatusCode::CodeGenError:
+    case StatusCode::ExpressionValidationError:
+    case StatusCode::ExecutionError:
+      return ADBC_STATUS_INTERNAL;
+    case StatusCode::UnknownError:
+    default:
+      // Unknown and unlisted codes share the fallback result.
+      break;
+  }
+  return ADBC_STATUS_UNKNOWN;
+}
+
+void ReleaseError(struct AdbcError* error) {
+  // Free the message buffer; delete[] on a null pointer is a no-op, so no check is needed.
+  char* owned_message = error->message;
+  error->message = nullptr;
+  delete[] owned_message;
+}
+
+FlightClientOptions DefaultClientOptions() {  // Flight defaults + tracing + user agent
+  auto options = FlightClientOptions::Defaults();
+  // The gRPC primary user agent identifies this driver and the Arrow version.
+  options.generic_options.emplace_back("grpc.primary_user_agent",
+                                       "ADBC Flight SQL " ARROW_VERSION_STRING);
+  options.middleware.push_back(MakeTracingClientMiddlewareFactory());
+  return options;
+}
+
+Status IndexDbSchemas(RecordBatchReader* reader, TableIndex* table_index) {
+  // Record every (catalog, db schema) row with an empty (0, 0) range. TODO: unit test
+  // ReaderIterator only stores a reference, so keep the expected schema alive here.
+  const auto expected_schema = SqlSchema::GetDbSchemasSchema();
+  ReaderIterator db_schemas(*expected_schema, reader);
+  ARROW_RETURN_NOT_OK(db_schemas.Init());
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(bool have_data, db_schemas.Next());
+    if (!have_data) break;
+    std::optional catalog(db_schemas.GetNullable(0));
+    std::optional schema(db_schemas.GetNullable(1));
+    table_index->insert({catalog, {}}).first->second.insert({schema, {0, 0}});
+  }
+  return Status::OK();
+}
+
+Status IndexTables(RecordBatchReader* reader, TableIndex* table_index,
+                   std::shared_ptr* table_data) {
+  // ReaderIterator only stores a reference, so keep the expected schema alive here.
+  const auto expected_schema = SqlSchema::GetTablesSchemaWithIncludedSchema();
+  ReaderIterator tables(*expected_schema, reader);
+  ARROW_RETURN_NOT_OK(tables.Init());
+
+  int64_t start_index = 0;
+  int64_t index = 0;
+  RecordBatchVector batches;
+
+  std::optional catalog_name;
+  std::optional db_schema_name;
+  bool first = true;
+
+  // Store [start_index, index) for the current (catalog, schema) run. Runs whose
+  // catalog is not already present in the index are dropped.
+  const auto record_current_run = [&]() {
+    if (first) return;
+    auto it = table_index->find(catalog_name);
+    if (it != table_index->end()) {
+      it->second.insert_or_assign(db_schema_name, std::make_pair(start_index, index));
+    }
+  };
+
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(bool have_data, tables.Next());
+    if (!have_data) break;
+    // Collect each batch once, when the iterator first enters it.
+    if (auto batch = tables.NewBatch(); batch != nullptr) {
+      batches.push_back(std::move(batch));
+    }
+
+    std::optional catalog(tables.GetNullable(0));
+    std::optional schema(tables.GetNullable(1));
+    // A change of (catalog, schema) closes the previous run and opens a new one.
+    if (first || catalog != catalog_name || schema != db_schema_name) {
+      record_current_run();
+      catalog_name = std::move(catalog);
+      db_schema_name = std::move(schema);
+      start_index = index;
+      first = false;
+    }
+    index++;
+  }
+  record_current_run();
+
+  ARROW_ASSIGN_OR_RAISE(auto all_data,
+                        Table::FromRecordBatches(expected_schema, std::move(batches)));
+  ARROW_ASSIGN_OR_RAISE(*table_data, all_data->CombineChunksToBatch());
+  return Status::OK();
+}
+
+Status ExportRecordBatches(std::shared_ptr schema, RecordBatchVector batches,
+                           struct ArrowArrayStream* stream) {
+  // Wrap the batches in a RecordBatchReader, then export that reader into `stream`.
+  ARROW_ASSIGN_OR_RAISE(auto batch_reader,
+                        RecordBatchReader::Make(std::move(batches), std::move(schema)));
+  return arrow::ExportRecordBatchReader(std::move(batch_reader), stream);
+}
+
+bool IsApproximatelyValidIdentifier(std::string_view name) {
+  // Accepts only ASCII names of the form [A-Za-z_][A-Za-z0-9_]*.
+  const auto is_letter_or_underscore = [](char c) {
+    return c == '_' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
+  };
+  const auto is_digit = [](char c) { return c >= '0' && c <= '9'; };
+
+  // The name must be non-empty and must start with a letter or underscore.
+  if (name.empty() || !is_letter_or_underscore(name.front())) return false;
+
+  // Every later character may additionally be a digit.
+  for (const char c : name.substr(1)) {
+    if (!is_letter_or_underscore(c) && !is_digit(c)) return false;
+  }
+  return true;
+}
+
+} // namespace arrow::flight::sql
diff --git a/cpp/src/arrow/flight/sql/adbc_driver_internal.h b/cpp/src/arrow/flight/sql/adbc_driver_internal.h
new file mode 100644
index 00000000000..4eec8389525
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/adbc_driver_internal.h
@@ -0,0 +1,198 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "arrow/array/array_binary.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/client_tracing_middleware.h"
+#include "arrow/flight/sql/server.h"
+#include "arrow/flight/sql/visibility.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_builder.h"
+
+// Override the built-in visibility defines with Arrow's
+#define ADBC_EXPORT ARROW_FLIGHT_SQL_EXPORT
+#include "arrow/c/adbc_internal.h" // IWYU pragma: export
+
+namespace arrow::flight::sql {
+
+// Internal utilities for the Flight SQL driver.
+
+#define ADBC_RETURN_NOT_OK(EXPR) /* return EXPR's result from the caller unless it equals ADBC_STATUS_OK */ \
+ do { \
+ auto _s = (EXPR); \
+ if (_s != ADBC_STATUS_OK) return _s; \
+ } while (false)
+
+#ifdef ARROW_EXTRA_ERROR_CONTEXT  // this variant also records file, line, and expression
+#define ADBC_ARROW_RETURN_NOT_OK(ERROR, EXPR) /* on failure: fill ERROR, return ADBC code */ \
+  do { \
+    if (Status _s = (EXPR); !_s.ok()) { \
+      _s.AddContextLine(__FILE__, __LINE__, ARROW_STRINGIFY(EXPR)); \
+      SetError((ERROR), _s); /* use the macro argument, not a caller-scope `error` */ \
+      return ::arrow::flight::sql::ArrowToAdbcStatusCode(_s); \
+    } \
+  } while (false)
+#else
+#define ADBC_ARROW_RETURN_NOT_OK(ERROR, EXPR) /* on failure: fill ERROR, return ADBC code */ \
+  do { \
+    if (Status _s = (EXPR); !_s.ok()) { \
+      SetError((ERROR), _s); /* use the macro argument, not a caller-scope `error` */ \
+      return ::arrow::flight::sql::ArrowToAdbcStatusCode(_s); \
+    } \
+  } while (false)
+#endif
+
+ARROW_FLIGHT_SQL_EXPORT
+AdbcStatusCode ArrowToAdbcStatusCode(const Status& status);
+ARROW_FLIGHT_SQL_EXPORT
+void ReleaseError(struct AdbcError* error);
+ARROW_FLIGHT_SQL_EXPORT
+FlightClientOptions DefaultClientOptions();
+
+/// Helper to populate an AdbcError: formats args behind a "[Flight SQL] " prefix; does nothing when error is null.
+template
+void SetError(struct AdbcError* error, Args&&... args) {
+ if (!error) return;
+ std::string message = util::StringBuilder("[Flight SQL] ", std::forward(args)...);
+ if (error->message) {  // keep an earlier message by appending it after the new one
+ message.reserve(message.size() + 1 + std::strlen(error->message));
+ message.append(1, '\n');
+ message.append(error->message);
+ delete[] error->message;  // NOTE(review): assumes the old buffer came from new[] — confirm for errors set elsewhere
+ }
+ error->message = new char[message.size() + 1];  // +1 for the terminating NUL written below
+ message.copy(error->message, message.size());
+ error->message[message.size()] = '\0';
+ error->release = ReleaseError;  // ReleaseError frees the buffer with delete[]
+}
+
+template ::ArrayType>
+auto GetNotNull(const RecordBatch& batch, int col_index, int64_t row_index) {  // view of one cell; no null check is made
+ return arrow::internal::checked_cast(*batch.column(col_index))
+ .GetView(row_index);
+}
+
+template ::ArrayType>
+std::optional>
+GetNullable(const RecordBatch& batch, int col_index, int64_t row_index) {  // view of one cell, or std::nullopt
+ // compiler unfortunately can't infer the return type here
+ const auto& arr =
+ arrow::internal::checked_cast(*batch.column(col_index));
+ if (arr.IsNull(row_index)) return std::nullopt;  // a null cell maps to an empty optional
+ return arr.GetView(row_index);
+}
+
+/// \brief Helper to iterate over a RecordBatchReader in a rowwise
+/// manner
+class ARROW_FLIGHT_SQL_EXPORT ReaderIterator {
+ public:
+  explicit ReaderIterator(const Schema& schema, RecordBatchReader* reader)
+      : schema_(schema), reader_(reader) {}
+  // Check the reader's schema against the expected one and load the first batch.
+  Status Init() {
+    if (!schema_.Equals(*reader_->schema())) {
+      return Status::Invalid("Server sent the wrong schema.\nExpected:", schema_,
+                             "\nActual:", *reader_->schema());
+    }
+    return reader_->Next().Value(&current_);
+  }
+  // The current batch if the iterator is on its first row, otherwise null.
+  std::shared_ptr<RecordBatch> NewBatch() const {
+    return current_row_ == 0 ? current_ : nullptr;
+  }
+  // Advance one row, skipping empty batches; false once the reader is exhausted.
+  arrow::Result<bool> Next() {
+    if (done_) return false;
+
+    current_row_++;
+    while (current_ && current_row_ >= current_->num_rows()) {
+      ARROW_ASSIGN_OR_RAISE(current_, reader_->Next());
+      if (!current_) {
+        done_ = true;
+        return false;
+      }
+      current_row_ = 0;
+    }
+    return current_ != nullptr;
+  }
+  // Value at the current row of column col_index (no null check).
+  template ::ArrayType>
+  std::invoke_result_t GetNotNull(
+      int col_index) const {
+    return sql::GetNotNull(*current_, col_index, current_row_);
+  }
+  // Value at the current row of column col_index, or std::nullopt if it is null.
+  template ::ArrayType>
+  std::optional>
+  GetNullable(int col_index) const {
+    return sql::GetNullable(*current_, col_index, current_row_);
+  }
+
+ private:
+  const Schema& schema_;  // not owned: the caller must keep it alive for the iterator's lifetime
+  RecordBatchReader* reader_;
+  std::shared_ptr<RecordBatch> current_;
+  int64_t current_row_ = -1;  // -1 until the first call to Next()
+  bool done_ = false;
+};
+
+// {[catalog name]: {[db schema name]: (inclusive_lower_bound, exclusive_upper_bound)}}
+using TableIndex = std::unordered_map<
+ std::optional,
+ std::unordered_map, std::pair>>;
+
+/// Build up an index of database metadata in-memory to help implement GetObjects.
+ARROW_FLIGHT_SQL_EXPORT
+Status IndexDbSchemas(RecordBatchReader* reader, TableIndex* table_index);
+
+/// Build up an index of database metadata in-memory to help implement GetObjects.
+ARROW_FLIGHT_SQL_EXPORT
+Status IndexTables(RecordBatchReader* reader, TableIndex* table_index,
+ std::shared_ptr* table_data);
+
+/// Export batches of data as an ArrayStream.
+ARROW_FLIGHT_SQL_EXPORT
+Status ExportRecordBatches(std::shared_ptr schema, RecordBatchVector batches,
+ struct ArrowArrayStream* stream);
+
+/// Check if a name is (probably) a valid SQL identifier. We don't
+/// know the exact SQL syntax (though we could try to take advantage
+/// of the info in GetSqlInfo), so this function is conservative about
+/// allowed identifiers.
+ARROW_FLIGHT_SQL_EXPORT
+bool IsApproximatelyValidIdentifier(std::string_view name);
+
+} // namespace arrow::flight::sql
diff --git a/cpp/src/arrow/flight/sql/adbc_driver_test.cc b/cpp/src/arrow/flight/sql/adbc_driver_test.cc
new file mode 100644
index 00000000000..1b060ebc2cb
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/adbc_driver_test.cc
@@ -0,0 +1,1057 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "arrow/c/bridge.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/sql/adbc_driver_internal.h"
+#include "arrow/flight/sql/example/sqlite_server.h"
+#include "arrow/flight/sql/server.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/config.h"
+
+#include "adbc_validation/adbc_validation.h"
+#include "adbc_validation/adbc_validation_util.h"
+
+using arrow::internal::checked_cast;
+
+namespace arrow::flight::sql {
+
+using adbc_validation::IsOkStatus;
+
+#define ARROW_ADBC_RETURN_NOT_OK(STATUS, ERROR, EXPR) /* turn a failed ADBC call into Status::STATUS(code name, message) */ \
+ do { \
+ if (AdbcStatusCode _s = (EXPR); _s != ADBC_STATUS_OK) { \
+ return Status::STATUS(adbc_validation::StatusCodeToString(_s), ": ", \
+ (ERROR)->message ? (ERROR)->message : "(no message)"); \
+ } \
+ } while (false)
+
+// ------------------------------------------------------------
+// ADBC Test Suite
+
+class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks {  // quirks for running the ADBC validation suite against the example SQLite Flight SQL server
+ public:
+ AdbcStatusCode SetupDatabase(struct AdbcDatabase* database,
+ struct AdbcError* error) const override {
+ std::string location = location_.ToString();
+ if (auto status = AdbcDatabaseSetOption(  // set the int64 ingest type to "INT" before setting the URI
+ database, "arrow.flight.sql.quirks.ingest_type.int64", "INT", error);
+ status != ADBC_STATUS_OK) {
+ return status;
+ }
+ return AdbcDatabaseSetOption(database, "uri", location.c_str(), error);
+ }
+
+ [[nodiscard]] std::string BindParameter(int index) const override { return "?"; }  // same placeholder for every index
+ [[nodiscard]] bool supports_concurrent_statements() const override { return true; }
+ [[nodiscard]] bool supports_partitioned_data() const override { return true; }
+
+ void StartServer() {  // uses gtest assertions, so callers should wrap it in ASSERT_NO_FATAL_FAILURE
+ ASSERT_OK_AND_ASSIGN(auto bind_location, Location::ForGrpcTcp("0.0.0.0", 0));
+ arrow::flight::FlightServerOptions options(bind_location);
+ ASSERT_OK_AND_ASSIGN(server_, example::SQLiteFlightSqlServer::Create());
+ ASSERT_OK(server_->Init(options));
+
+ ASSERT_OK_AND_ASSIGN(location_, Location::ForGrpcTcp("localhost", server_->port()));  // clients connect via the port the server reports
+ }
+
+ void StopServer() { ASSERT_OK(server_->Shutdown()); }
+
+ private:
+ std::shared_ptr server_;
+ Location location_;  // set by StartServer(); read by SetupDatabase()
+};
+
+class AdbcSqliteDatabaseTest : public ::testing::Test,
+ public adbc_validation::DatabaseTest {
+ public:
+ [[nodiscard]] const adbc_validation::DriverQuirks* quirks() const override {
+ return &quirks_;
+ }
+ void SetUp() override {  // start the server first, then run the suite's own setup
+ ASSERT_NO_FATAL_FAILURE(quirks_.StartServer());
+ ASSERT_NO_FATAL_FAILURE(SetUpTest());
+ }
+ void TearDown() override {  // NOTE(review): a fatal failure in TearDownTest() returns early and skips StopServer()
+ ASSERT_NO_FATAL_FAILURE(TearDownTest());
+ ASSERT_NO_FATAL_FAILURE(quirks_.StopServer());
+ }
+
+ protected:
+ SqliteFlightSqlQuirks quirks_;
+};
+ADBCV_TEST_DATABASE(AdbcSqliteDatabaseTest)
+
+class AdbcSqliteConnectionTest : public ::testing::Test,
+ public adbc_validation::ConnectionTest {
+ public:
+ [[nodiscard]] const adbc_validation::DriverQuirks* quirks() const override {
+ return &quirks_;
+ }
+ void SetUp() override {  // start the server first, then run the suite's own setup
+ ASSERT_NO_FATAL_FAILURE(quirks_.StartServer());
+ ASSERT_NO_FATAL_FAILURE(SetUpTest());
+ }
+ void TearDown() override {  // NOTE(review): a fatal failure in TearDownTest() returns early and skips StopServer()
+ ASSERT_NO_FATAL_FAILURE(TearDownTest());
+ ASSERT_NO_FATAL_FAILURE(quirks_.StopServer());
+ }
+
+#if !defined(ARROW_COMPUTE)
+ void TestMetadataGetObjectsColumns() {  // replaces the suite's test body with a skip in builds without ARROW_COMPUTE
+ GTEST_SKIP() << "Test fails without ARROW_COMPUTE";
+ }
+#endif
+
+ protected:
+ SqliteFlightSqlQuirks quirks_;
+};
+ADBCV_TEST_CONNECTION(AdbcSqliteConnectionTest)
+
+class AdbcSqliteStatementTest : public ::testing::Test,
+ public adbc_validation::StatementTest {
+ public:
+ [[nodiscard]] const adbc_validation::DriverQuirks* quirks() const override {
+ return &quirks_;
+ }
+ void SetUp() override {  // start the server first, then run the suite's own setup
+ ASSERT_NO_FATAL_FAILURE(quirks_.StartServer());
+ ASSERT_NO_FATAL_FAILURE(SetUpTest());
+ }
+ void TearDown() override {  // NOTE(review): a fatal failure in TearDownTest() returns early and skips StopServer()
+ ASSERT_NO_FATAL_FAILURE(TearDownTest());
+ ASSERT_NO_FATAL_FAILURE(quirks_.StopServer());
+ }
+
+ protected:
+ SqliteFlightSqlQuirks quirks_;
+};
+ADBCV_TEST_STATEMENT(AdbcSqliteStatementTest)
+
+// Ensure partitions can be introspected by clients who know they're
+// using Flight SQL
+TEST_F(AdbcSqliteStatementTest, IntrospectPartitions) {
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
+ IsOkStatus(&error));
+
+ struct ArrowSchema c_schema;
+ struct AdbcPartitions partitions;
+ ASSERT_THAT(AdbcStatementExecutePartitions(&statement, &c_schema, &partitions,
+ /*rows_affected=*/nullptr, &error),
+ IsOkStatus(&error));
+
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+
+ ASSERT_GT(partitions.num_partitions, 0);
+ for (size_t i = 0; i < partitions.num_partitions; i++) {
+ ASSERT_OK_AND_ASSIGN(auto info,
+ FlightInfo::Deserialize(std::string_view(
+ reinterpret_cast(partitions.partitions[i]),
+ partitions.partition_lengths[i])));
+ ipc::DictionaryMemo memo;
+ ASSERT_OK_AND_ASSIGN(auto info_schema, info->GetSchema(&memo));
+ ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(*schema, *info_schema, /*verbose=*/true));
+ ASSERT_EQ(1, info->endpoints().size());
+ }
+
+ partitions.release(&partitions);
+}
+
+constexpr std::initializer_list kInvalidNames = {
+ "???", "", "名前", "[quoted]", "9abc", "foo-bar"};
+
+TEST_F(AdbcSqliteStatementTest, InvalidTableName) {
+ for (const auto& name : kInvalidNames) {
+ ARROW_SCOPED_TRACE(name);
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(
+ AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, name, &error),
+ IsOkStatus(&error));
+ auto schema = arrow::schema({field("ints", int64())});
+ auto batch = RecordBatchFromJSON(schema, R"([[1]])");
+ struct ArrowSchema c_schema = {};
+ struct ArrowArray c_batch = {};
+ ASSERT_OK(ExportSchema(*schema, &c_schema));
+ ASSERT_OK(ExportRecordBatch(*batch, &c_batch));
+ ASSERT_THAT(AdbcStatementBind(&statement, &c_batch, &c_schema, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, /*out=*/nullptr,
+ /*rows_affected=*/nullptr, &error),
+ ::testing::Not(IsOkStatus(&error)));
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+ }
+}
+
+TEST_F(AdbcSqliteStatementTest, InvalidColumnName) {
+ for (const auto& name : kInvalidNames) {
+ ARROW_SCOPED_TRACE(name);
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+ "foobar", &error),
+ IsOkStatus(&error));
+ auto schema = arrow::schema({field(name, int64())});
+ auto batch = RecordBatchFromJSON(schema, R"([[1]])");
+ struct ArrowSchema c_schema = {};
+ struct ArrowArray c_batch = {};
+ ASSERT_OK(ExportSchema(*schema, &c_schema));
+ ASSERT_OK(ExportRecordBatch(*batch, &c_batch));
+ ASSERT_THAT(AdbcStatementBind(&statement, &c_batch, &c_schema, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, /*out=*/nullptr,
+ /*rows_affected=*/nullptr, &error),
+ ::testing::Not(IsOkStatus(&error)));
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+ }
+}
+
+TEST(FlightSqlInternals, NameSanitizer) {
+ for (const auto& name : kInvalidNames) {  // every name in the shared invalid list must be rejected
+ ASSERT_FALSE(IsApproximatelyValidIdentifier(name)) << name;
+ }
+ for (const auto& name : {"_foo", "a9", "foo_bar", "ABCD"}) {  // plain ASCII identifiers must be accepted
+ ASSERT_TRUE(IsApproximatelyValidIdentifier(name)) << name;
+ }
+}
+
+class AdbcTimeoutTestServer : public FlightSqlServerBase {  // test server whose handlers can block until the call is cancelled
+ arrow::Result BeginTransaction(
+ const ServerCallContext& context,
+ const ActionBeginTransactionRequest& request) override {
+ while (!context.is_cancelled()) {  // always blocks, polling every 100 ms
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ return Status::NotImplemented("NYI");
+ }
+
+ arrow::Result> DoGetStatement(
+ const ServerCallContext& context, const StatementQueryTicket& command) override {
+ while (!context.is_cancelled()) {  // always blocks, polling every 100 ms
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ return Status::NotImplemented("NYI");
+ }
+
+ arrow::Result DoPutCommandStatementUpdate(
+ const ServerCallContext& context, const StatementUpdate& command) override {
+ if (command.query == "timeout") {  // only the "timeout" query blocks; others fail immediately
+ while (!context.is_cancelled()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ }
+ return Status::NotImplemented("NYI");
+ }
+
+ arrow::Result> GetFlightInfoStatement(
+ const ServerCallContext& context, const StatementQuery& command,
+ const FlightDescriptor& descriptor) override {
+ if (command.query == "timeout") {  // "timeout" blocks until cancelled, then falls through to NYI
+ while (!context.is_cancelled()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ } else if (command.query == "fetch") {  // "fetch" succeeds with one endpoint so the client goes on to DoGetStatement
+ ARROW_ASSIGN_OR_RAISE(std::string ticket, CreateStatementQueryTicket("fetch"));
+ ARROW_ASSIGN_OR_RAISE(
+ auto info, FlightInfo::Make(Schema({}), descriptor,
+ {
+ FlightEndpoint{Ticket{std::move(ticket)}, {}},
+ },
+ /*total_records=*/-1, /*total_bytes=*/-1));
+ return std::make_unique(std::move(info));
+ }
+ return Status::NotImplemented("NYI");
+ }
+};
+
+template
+class AdbcServerTest : public ::testing::Test {
+ protected:
+  virtual arrow::Result GetLocation(const std::string& host, int port) {
+    return Location::ForGrpcTcp(host, port);
+  }
+  virtual void Configure(FlightServerOptions* options) {}
+  void SetUp() override {  // start the server on port 0, then open a database and a connection
+    ASSERT_OK_AND_ASSIGN(auto bind_location, GetLocation("0.0.0.0", 0));
+    FlightServerOptions options(bind_location);
+    Configure(&options);
+    server_ = std::make_unique();
+    ASSERT_OK(server_->Init(options));
+
+    ASSERT_OK_AND_ASSIGN(auto connect_location,
+                         GetLocation("localhost", server_->port()));
+    std::string uri = connect_location.ToString();
+    // Pass &error_ so that a failing call reports the driver's message.
+    ASSERT_THAT(AdbcDatabaseNew(&database_, &error_), IsOkStatus(&error_));
+    ASSERT_THAT(AdbcDatabaseSetOption(&database_, "uri", uri.c_str(), &error_),
+                IsOkStatus(&error_));
+    ASSERT_THAT(AdbcDatabaseInit(&database_, &error_), IsOkStatus(&error_));
+    ASSERT_THAT(AdbcConnectionNew(&connection_, &error_), IsOkStatus(&error_));
+    ASSERT_THAT(AdbcConnectionInit(&connection_, &database_, &error_),
+                IsOkStatus(&error_));
+  }
+  void TearDown() override {
+    // Drop any error left over from the test body before reusing error_.
+    if (error_.release) error_.release(&error_);
+    error_ = {};
+    ASSERT_THAT(AdbcConnectionRelease(&connection_, &error_), IsOkStatus(&error_));
+    ASSERT_THAT(AdbcDatabaseRelease(&database_, &error_), IsOkStatus(&error_));
+    ASSERT_OK(server_->Shutdown());
+  }
+
+  std::unique_ptr server_;
+  // Zero-initialized so the ADBC calls and the matcher never read indeterminate memory.
+  AdbcDatabase database_ = {};
+  AdbcConnection connection_ = {};
+  AdbcError error_ = {};
+};
+
+class AdbcTimeoutTest : public AdbcServerTest {};
+
+TEST_F(AdbcTimeoutTest, InvalidValues) {
+ for (const auto& key : {  // each of the three timeout options is checked
+ "arrow.flight.sql.rpc.timeout_seconds.fetch",
+ "arrow.flight.sql.rpc.timeout_seconds.query",
+ "arrow.flight.sql.rpc.timeout_seconds.update",
+ }) {
+ for (const auto& value : {"1.1f", "asdf", "inf", "NaN", "-1"}) {  // malformed, non-finite, and negative values
+ ARROW_SCOPED_TRACE("key=", key, " value=", value);
+ ASSERT_EQ(ADBC_STATUS_INVALID_ARGUMENT,
+ AdbcConnectionSetOption(&connection_, key, value, &error_));
+ ASSERT_THAT(error_.message, ::testing::HasSubstr("Invalid timeout option value"));
+ }
+ }
+}
+
+TEST_F(AdbcTimeoutTest, RemoveTimeout) {
+ for (const auto& key : {
+ "arrow.flight.sql.rpc.timeout_seconds.fetch",
+ "arrow.flight.sql.rpc.timeout_seconds.query",
+ "arrow.flight.sql.rpc.timeout_seconds.update",
+ }) {
+ ARROW_SCOPED_TRACE("key=", key);
+ ASSERT_THAT(AdbcConnectionSetOption(&connection_, key, "1.0", &error_),  // set a timeout first
+ IsOkStatus(&error_));
+ ASSERT_THAT(AdbcConnectionSetOption(&connection_, key, "0", &error_),  // "0" must be accepted (presumably clears the timeout, per the test name)
+ IsOkStatus(&error_));
+ }
+}
+
+TEST_F(AdbcTimeoutTest, DoActionTimeout) {
+ AdbcStatusCode status;
+
+ status = AdbcConnectionSetOption(
+ &connection_, "arrow.flight.sql.rpc.timeout_seconds.update", "0.1", &error_);
+ ASSERT_THAT(status, IsOkStatus(&error_));
+
+ ASSERT_THAT(AdbcConnectionSetOption(&connection_, ADBC_CONNECTION_OPTION_AUTOCOMMIT,  // presumably reaches the server's blocking BeginTransaction — confirm in the driver
+ ADBC_OPTION_VALUE_DISABLED, &error_),
+ ::testing::Not(IsOkStatus(&error_)));
+ ASSERT_THAT(error_.message, ::testing::Not(::testing::HasSubstr("NYI")));  // "NYI" would mean the handler returned normally rather than being cut short
+}
+
+TEST_F(AdbcTimeoutTest, DoGetTimeout) {
+ struct AdbcStatement stmt = {};
+ struct ArrowArrayStream out = {};
+
+ ASSERT_THAT(
+ AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.timeout_seconds.fetch",
+ "0.1", &error_),
+ IsOkStatus(&error_));
+
+ ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "fetch", &error_), IsOkStatus(&error_));
+ Status st = ([&]() -> Status {
+ auto status =
+ AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_);
+ if (status != ADBC_STATUS_OK) return Status::Invalid(error_.message);
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr reader,
+ arrow::ImportRecordBatchReader(&out));
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr rb, reader->Next());
+ if (!rb) break;
+ }
+ return Status::OK();
+ })();
+ ASSERT_NOT_OK(st);
+ ASSERT_THAT(st.ToString(), ::testing::Not(::testing::HasSubstr("NYI")));
+
+ ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_));
+}
+
+TEST_F(AdbcTimeoutTest, DoPutTimeout) {
+ struct AdbcStatement stmt = {};
+
+ ASSERT_THAT(
+ AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.timeout_seconds.update",
+ "0.1", &error_),
+ IsOkStatus(&error_));
+
+ ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, /*out=*/nullptr, /*rows_affected=*/nullptr,
+ &error_),
+ ::testing::Not(IsOkStatus(&error_)));
+ ASSERT_THAT(error_.message, ::testing::Not(::testing::HasSubstr("NYI")));
+
+ ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_));
+}
+
+TEST_F(AdbcTimeoutTest, GetFlightInfoTimeout) {
+ struct AdbcStatement stmt = {};
+ struct ArrowArrayStream out = {};
+
+ ASSERT_THAT(
+ AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.timeout_seconds.query",
+ "0.1", &error_),
+ IsOkStatus(&error_))
+ << error_.message;
+
+ ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_),
+ ::testing::Not(IsOkStatus(&error_)));
+ ASSERT_THAT(error_.message, ::testing::Not(::testing::HasSubstr("NYI")));
+
+ if (out.release) out.release(&out);
+ ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_));
+}
+
+/// Ensure options to add custom headers make it through.
+class HeaderServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr* middleware) override {  // records the headers only; *middleware is never assigned
+ for (const auto& header : incoming_headers) {
+ recorded_headers_.emplace_back(header.first, header.second);
+ }
+ return Status::OK();
+ }
+
+ std::vector> recorded_headers_;  // NOTE(review): unsynchronized; presumably tests read it only after the RPC returns — confirm
+};
+
+class AdbcHeaderTest : public AdbcServerTest {
+ protected:
+ void Configure(FlightServerOptions* options) override {  // registers the recording factory under the key "headers"
+ headers_ = std::make_shared();
+ options->middleware.emplace_back("headers", headers_);
+ }
+ std::shared_ptr headers_;  // kept so tests can inspect recorded_headers_
+};
+
+TEST_F(AdbcHeaderTest, Connection) {
+ struct AdbcStatement stmt = {};
+ struct ArrowArrayStream out = {};
+
+ ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_));
+
+ ASSERT_THAT(
+ AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.call_header.x-span-id",
+ "my span id", &error_),
+ IsOkStatus(&error_));
+ ASSERT_THAT(AdbcConnectionSetOption(&connection_,
+ "arrow.flight.sql.rpc.call_header.x-user-agent",
+ "Flight SQL ADBC", &error_),
+ IsOkStatus(&error_))
+ << error_.message;
+
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_),
+ ::testing::Not(IsOkStatus(&error_)));
+
+ ASSERT_THAT(AdbcConnectionSetOption(&connection_, ADBC_CONNECTION_OPTION_AUTOCOMMIT,
+ ADBC_OPTION_VALUE_DISABLED, &error_),
+ ::testing::Not(IsOkStatus(&error_)));
+ ASSERT_THAT(headers_->recorded_headers_,
+ ::testing::Contains(std::make_pair("x-span-id", "my span id")));
+ ASSERT_THAT(headers_->recorded_headers_,
+ ::testing::Contains(std::make_pair("x-user-agent", "Flight SQL ADBC")));
+ headers_->recorded_headers_.clear();
+
+ if (out.release) out.release(&out);
+ ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_));
+}
+
+TEST_F(AdbcHeaderTest, Database) {
+ struct AdbcStatement stmt = {};
+ struct ArrowArrayStream out = {};
+
+ ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_));
+
+ ASSERT_THAT(
+ AdbcDatabaseSetOption(&database_, "arrow.flight.sql.rpc.call_header.x-span-id",
+ "my span id", &error_),
+ IsOkStatus(&error_));
+ ASSERT_THAT(AdbcConnectionSetOption(&connection_,
+ "arrow.flight.sql.rpc.call_header.x-user-agent",
+ "Flight SQL ADBC", &error_),
+ IsOkStatus(&error_));
+
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_),
+ ::testing::Not(IsOkStatus(&error_)));
+
+ ASSERT_THAT(headers_->recorded_headers_,
+ ::testing::Contains(std::make_pair("x-span-id", "my span id")));
+ ASSERT_THAT(headers_->recorded_headers_,
+ ::testing::Contains(std::make_pair("x-user-agent", "Flight SQL ADBC")));
+ headers_->recorded_headers_.clear();
+
+ if (out.release) out.release(&out);
+ ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_));
+}
+
+// Exercises statement-level call headers in three phases: a header set on the
+// statement is sent, setting the option to NULL removes it again, and a header set on
+// the connection afterwards is picked up by the already-existing statement.
+TEST_F(AdbcHeaderTest, Statement) {
+  using ::testing::Contains;
+  using ::testing::Not;
+
+  const char* const kSpanIdOption = "arrow.flight.sql.rpc.call_header.x-span-id";
+
+  struct AdbcStatement statement = {};
+  struct ArrowArrayStream stream = {};
+  ASSERT_THAT(AdbcStatementNew(&connection_, &statement, &error_),
+              IsOkStatus(&error_));
+
+  // Phase 1: header set directly on the statement.
+  ASSERT_THAT(AdbcStatementSetOption(&statement, kSpanIdOption, "my span id", &error_),
+              IsOkStatus(&error_));
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "timeout", &error_),
+              IsOkStatus(&error_));
+  ASSERT_THAT(
+      AdbcStatementExecuteQuery(&statement, &stream, /*rows_affected=*/nullptr, &error_),
+      Not(IsOkStatus(&error_)));
+  ASSERT_THAT(headers_->recorded_headers_,
+              Contains(std::make_pair("x-span-id", "my span id")));
+  headers_->recorded_headers_.clear();
+
+  // Phase 2: a NULL value erases the header.
+  ASSERT_THAT(AdbcStatementSetOption(&statement, kSpanIdOption, nullptr, &error_),
+              IsOkStatus(&error_));
+  ASSERT_THAT(
+      AdbcStatementExecuteQuery(&statement, &stream, /*rows_affected=*/nullptr, &error_),
+      Not(IsOkStatus(&error_)));
+  ASSERT_THAT(headers_->recorded_headers_,
+              Not(Contains(std::make_pair("x-span-id", "my span id"))));
+  headers_->recorded_headers_.clear();
+
+  // Phase 3: connection headers are inherited by the statement.
+  ASSERT_THAT(
+      AdbcConnectionSetOption(&connection_, kSpanIdOption, "my span id", &error_),
+      IsOkStatus(&error_));
+  ASSERT_THAT(
+      AdbcStatementExecuteQuery(&statement, &stream, /*rows_affected=*/nullptr, &error_),
+      Not(IsOkStatus(&error_)));
+  ASSERT_THAT(headers_->recorded_headers_,
+              Contains(std::make_pair("x-span-id", "my span id")));
+  headers_->recorded_headers_.clear();
+
+  if (stream.release != nullptr) {
+    stream.release(&stream);
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error_), IsOkStatus(&error_));
+}
+
+// Sets a distinct header at each of the three levels (database, connection,
+// statement) and checks that a single query records all of them.
+TEST_F(AdbcHeaderTest, Combined) {
+  using ::testing::Contains;
+  using ::testing::Not;
+
+  struct AdbcStatement statement = {};
+  struct ArrowArrayStream stream = {};
+
+  // Database and connection headers are configured before the statement exists.
+  ASSERT_THAT(
+      AdbcDatabaseSetOption(&database_, "arrow.flight.sql.rpc.call_header.x-header-one",
+                            "value 1", &error_),
+      IsOkStatus(&error_));
+  ASSERT_THAT(AdbcConnectionSetOption(&connection_,
+                                      "arrow.flight.sql.rpc.call_header.x-header-two",
+                                      "value 2", &error_),
+              IsOkStatus(&error_));
+
+  ASSERT_THAT(AdbcStatementNew(&connection_, &statement, &error_),
+              IsOkStatus(&error_));
+  ASSERT_THAT(AdbcStatementSetOption(&statement,
+                                     "arrow.flight.sql.rpc.call_header.x-header-three",
+                                     "value 3", &error_),
+              IsOkStatus(&error_));
+
+  // The "timeout" query is expected to fail; only the recorded headers matter here.
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "timeout", &error_),
+              IsOkStatus(&error_));
+  ASSERT_THAT(
+      AdbcStatementExecuteQuery(&statement, &stream, /*rows_affected=*/nullptr, &error_),
+      Not(IsOkStatus(&error_)));
+
+  ASSERT_THAT(headers_->recorded_headers_,
+              Contains(std::make_pair("x-header-one", "value 1")));
+  ASSERT_THAT(headers_->recorded_headers_,
+              Contains(std::make_pair("x-header-two", "value 2")));
+  ASSERT_THAT(headers_->recorded_headers_,
+              Contains(std::make_pair("x-header-three", "value 3")));
+
+  if (stream.release != nullptr) {
+    stream.release(&stream);
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error_), IsOkStatus(&error_));
+}
+
+// Flight SQL test server for Substrait plans. Every handler accepts only the literal
+// plan "expected plan" and answers anything else with Status::Invalid("invalid plan").
+class SubstraitTestServer : public FlightSqlServerBase {
+ public:
+  // Serves the ticket created by GetFlightInfoSubstraitPlan: returns an empty stream
+  // (no batches, empty schema) when the statement handle is the expected plan.
+  arrow::Result<std::unique_ptr<FlightDataStream>> DoGetStatement(
+      const ServerCallContext& context, const StatementQueryTicket& command) override {
+    if (command.statement_handle != "expected plan") {
+      return Status::Invalid("invalid plan");
+    }
+    ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, arrow::schema({})));
+    return std::make_unique<RecordBatchStream>(reader);
+  }
+
+  // Update path: reports 42 affected rows for the expected plan.
+  arrow::Result<int64_t> DoPutCommandSubstraitPlan(
+      const ServerCallContext& context, const StatementSubstraitPlan& command) override {
+    if (command.plan.plan != "expected plan") return Status::Invalid("invalid plan");
+    return 42;
+  }
+
+  // Query path: returns a FlightInfo with one endpoint that has no locations and whose
+  // ticket carries the plan text as the statement handle.
+  arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoSubstraitPlan(
+      const ServerCallContext& context, const StatementSubstraitPlan& command,
+      const FlightDescriptor& descriptor) override {
+    if (command.plan.plan != "expected plan") {
+      return Status::Invalid("invalid plan");
+    }
+    ARROW_ASSIGN_OR_RAISE(std::string ticket,
+                          CreateStatementQueryTicket(command.plan.plan));
+    std::vector<FlightEndpoint> endpoints = {
+        FlightEndpoint{Ticket{std::move(ticket)}, {}}};
+    ARROW_ASSIGN_OR_RAISE(auto info,
+                          FlightInfo::Make(Schema({}), descriptor, std::move(endpoints),
+                                           /*total_records=*/-1, /*total_bytes=*/-1));
+    return std::make_unique<FlightInfo>(std::move(info));
+  }
+};
+
+// Fixture for the Substrait plan tests; adds no members of its own to AdbcServerTest.
+// NOTE(review): AdbcServerTest is defined outside this section. If it is a class
+// template, its argument (presumably SubstraitTestServer) is missing here - confirm.
+class AdbcSubstraitTest : public AdbcServerTest {};
+
+// Executing a Substrait plan as a query must succeed and hand back a result stream.
+TEST_F(AdbcSubstraitTest, DoGet) {
+  struct AdbcStatement stmt = {};
+  struct ArrowArrayStream out = {};
+
+  std::string plan = "expected plan";
+
+  ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_));
+  // The ADBC API takes the serialized plan as raw bytes plus a length.
+  ASSERT_THAT(
+      AdbcStatementSetSubstraitPlan(&stmt, reinterpret_cast<const uint8_t*>(plan.data()),
+                                    plan.size(), &error_),
+      IsOkStatus(&error_));
+  int64_t rows_affected = 0;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, &rows_affected, &error_),
+              IsOkStatus(&error_));
+  // A non-null release callback shows that the driver populated the stream.
+  ASSERT_NE(nullptr, out.release);
+
+  out.release(&out);
+  ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_));
+}
+
+// Executing a Substrait plan without requesting a result stream must succeed and
+// report the affected-row count (42) through `rows_affected`.
+TEST_F(AdbcSubstraitTest, DoPut) {
+  struct AdbcStatement stmt = {};
+
+  std::string plan = "expected plan";
+
+  ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_));
+  // The ADBC API takes the serialized plan as raw bytes plus a length.
+  ASSERT_THAT(
+      AdbcStatementSetSubstraitPlan(&stmt, reinterpret_cast<const uint8_t*>(plan.data()),
+                                    plan.size(), &error_),
+      IsOkStatus(&error_));
+  int64_t rows_affected = 0;
+  // No output stream is passed; only the row count is requested.
+  ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, /*out=*/nullptr, &rows_affected, &error_),
+              IsOkStatus(&error_));
+  ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_));
+
+  ASSERT_EQ(42, rows_affected);
+}
+
+// Flight SQL test server whose FlightInfo responses advertise different combinations
+// of endpoint locations, selected by the query text:
+//   "NoLocation"       - one endpoint with an empty location list
+//   "FailedConnection" - host "unreachable" first, then "localhost"
+//   "FailedGet"        - "localhost" twice; the first DoGet for this handle fails
+//   "Failure"          - host "unreachable" twice
+// Any other query is rejected with Status::Invalid(query). All locations use this
+// server's own port.
+class FallbackTestServer : public FlightSqlServerBase {
+ public:
+  // Returns an empty stream (no batches, empty schema), except that the first request
+  // whose statement handle is "FailedGet" fails with an IOError.
+  arrow::Result<std::unique_ptr<FlightDataStream>> DoGetStatement(
+      const ServerCallContext& context, const StatementQueryTicket& command) override {
+    if (command.statement_handle == "FailedGet" && counter_++ == 0) {
+      // Fail the first time around
+      return Status::IOError("First time fails");
+    }
+    ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, arrow::schema({})));
+    return std::make_unique<RecordBatchStream>(reader);
+  }
+
+  // Builds a FlightInfo with a single endpoint whose ticket wraps the query text and
+  // whose locations depend on the query as described on the class.
+  arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoStatement(
+      const ServerCallContext& context, const StatementQuery& command,
+      const FlightDescriptor& descriptor) override {
+    ARROW_ASSIGN_OR_RAISE(std::string ticket, CreateStatementQueryTicket(command.query));
+    std::vector<FlightEndpoint> endpoints;
+    if (command.query == "NoLocation") {
+      endpoints = {FlightEndpoint{Ticket{std::move(ticket)}, {}}};
+    } else if (command.query == "FailedConnection") {
+      ARROW_ASSIGN_OR_RAISE(auto location1,
+                            Location::ForGrpcTcp("unreachable", this->port()));
+      ARROW_ASSIGN_OR_RAISE(auto location2,
+                            Location::ForGrpcTcp("localhost", this->port()));
+      endpoints = {
+          FlightEndpoint{Ticket{std::move(ticket)},
+                         {
+                             location1,
+                             location2,
+                         }},
+      };
+    } else if (command.query == "FailedGet") {
+      ARROW_ASSIGN_OR_RAISE(auto location,
+                            Location::ForGrpcTcp("localhost", this->port()));
+      endpoints = {
+          FlightEndpoint{Ticket{std::move(ticket)},
+                         {
+                             location,
+                             location,
+                         }},
+      };
+    } else if (command.query == "Failure") {
+      ARROW_ASSIGN_OR_RAISE(auto location,
+                            Location::ForGrpcTcp("unreachable", this->port()));
+      endpoints = {
+          FlightEndpoint{Ticket{std::move(ticket)},
+                         {
+                             location,
+                             location,
+                         }},
+      };
+    } else {
+      return Status::Invalid(command.query);
+    }
+    ARROW_ASSIGN_OR_RAISE(auto info,
+                          FlightInfo::Make(Schema({}), descriptor, std::move(endpoints),
+                                           /*total_records=*/-1, /*total_bytes=*/-1));
+    return std::make_unique<FlightInfo>(std::move(info));
+  }
+
+ private:
+  // Number of DoGetStatement calls seen so far for the "FailedGet" handle.
+  int32_t counter_ = 0;
+};
+
+// Fixture for the endpoint-fallback tests. Its helpers run a query through the ADBC
+// statement API and report the outcome as an Arrow Status/Result.
+// NOTE(review): AdbcServerTest is defined outside this section. If it is a class
+// template, its argument (presumably FallbackTestServer) is missing here - confirm.
+class AdbcFallbackTest : public AdbcServerTest {
+ protected:
+  // Executes `query` and discards the result stream.
+  //
+  // Returns OK only if setting the query, executing it and releasing the statement all
+  // succeed. Once the statement has been created it is released on every path, and the
+  // stream is released whenever the driver populated it.
+  Status DoQuery(const std::string& query) {
+    struct AdbcStatement stmt = {};
+    ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_,
+                             AdbcStatementNew(&connection_, &stmt, &error_));
+
+    // Zero-initialized so `out.release` is well-defined even when execution fails.
+    struct ArrowArrayStream out = {};
+    AdbcStatusCode status = AdbcStatementSetSqlQuery(&stmt, query.c_str(), &error_);
+    if (status == ADBC_STATUS_OK) {
+      status =
+          AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_);
+    }
+
+    // Clean up before inspecting any status so nothing leaks on the error paths.
+    const AdbcStatusCode release_status = AdbcStatementRelease(&stmt, &error_);
+    if (out.release) out.release(&out);
+
+    ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, release_status);
+    ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, status);
+    return Status::OK();
+  }
+
+  // Executes `query` with AdbcStatementExecutePartitions and deserializes every
+  // returned partition descriptor into a FlightInfo.
+  //
+  // Returns the FlightInfos in partition order, or the first error encountered. The
+  // schema, the partitions and the statement are released on every path.
+  arrow::Result<std::vector<std::unique_ptr<FlightInfo>>> GetPartitions(
+      const std::string& query) {
+    struct AdbcStatement stmt = {};
+    ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_,
+                             AdbcStatementNew(&connection_, &stmt, &error_));
+
+    // Zero-initialized so the release callbacks can be tested safely below.
+    struct ArrowSchema c_schema = {};
+    struct AdbcPartitions partitions = {};
+    AdbcStatusCode status = AdbcStatementSetSqlQuery(&stmt, query.c_str(), &error_);
+    if (status == ADBC_STATUS_OK) {
+      status = AdbcStatementExecutePartitions(&stmt, &c_schema, &partitions,
+                                              /*rows_affected=*/nullptr, &error_);
+    }
+
+    std::vector<std::unique_ptr<FlightInfo>> result;
+    Status deserialize_status;
+    if (status == ADBC_STATUS_OK) {
+      EXPECT_GT(partitions.num_partitions, 0);
+      for (size_t i = 0; i < partitions.num_partitions; i++) {
+        auto maybe_info = FlightInfo::Deserialize(
+            std::string_view(reinterpret_cast<const char*>(partitions.partitions[i]),
+                             partitions.partition_lengths[i]));
+        if (!maybe_info.ok()) {
+          // Stop at the first bad descriptor, but still run the cleanup below.
+          deserialize_status = maybe_info.status();
+          break;
+        }
+        result.emplace_back(maybe_info.MoveValueUnsafe());
+      }
+    }
+
+    if (c_schema.release) c_schema.release(&c_schema);
+    if (partitions.release) partitions.release(&partitions);
+    const AdbcStatusCode release_status = AdbcStatementRelease(&stmt, &error_);
+
+    ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, status);
+    ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, release_status);
+    ARROW_RETURN_NOT_OK(deserialize_status);
+    return result;
+  }
+};
+
+// "NoLocation": the query succeeds, and the single partition's only endpoint carries
+// an empty location list.
+TEST_F(AdbcFallbackTest, NoLocation) {
+  ASSERT_OK(DoQuery("NoLocation"));
+
+  ASSERT_OK_AND_ASSIGN(auto infos, GetPartitions("NoLocation"));
+  ASSERT_EQ(1, infos.size());
+  const auto& endpoints = infos[0]->endpoints();
+  ASSERT_EQ(1, endpoints.size());
+  ASSERT_TRUE(endpoints[0].locations.empty());
+}
+
+// "FailedConnection": the query still succeeds, and the single partition's only
+// endpoint lists two locations.
+TEST_F(AdbcFallbackTest, FailedConnection) {
+  ASSERT_OK(DoQuery("FailedConnection"));
+
+  ASSERT_OK_AND_ASSIGN(auto infos, GetPartitions("FailedConnection"));
+  ASSERT_EQ(1, infos.size());
+  const auto& endpoints = infos[0]->endpoints();
+  ASSERT_EQ(1, endpoints.size());
+  ASSERT_EQ(2, endpoints[0].locations.size());
+}
+
+// "FailedGet": the query still succeeds, and the single partition's only endpoint
+// lists two locations.
+TEST_F(AdbcFallbackTest, FailedGet) {
+  ASSERT_OK(DoQuery("FailedGet"));
+
+  ASSERT_OK_AND_ASSIGN(auto infos, GetPartitions("FailedGet"));
+  ASSERT_EQ(1, infos.size());
+  const auto& endpoints = infos[0]->endpoints();
+  ASSERT_EQ(1, endpoints.size());
+  ASSERT_EQ(2, endpoints[0].locations.size());
+}
+
+// "Failure": the query itself fails, yet the partitions can still be fetched; the
+// single partition's only endpoint lists two locations.
+TEST_F(AdbcFallbackTest, Failure) {
+  ASSERT_NOT_OK(DoQuery("Failure"));
+
+  ASSERT_OK_AND_ASSIGN(auto infos, GetPartitions("Failure"));
+  ASSERT_EQ(1, infos.size());
+  const auto& endpoints = infos[0]->endpoints();
+  ASSERT_EQ(1, endpoints.size());
+  ASSERT_EQ(2, endpoints[0].locations.size());
+}
+
+// Minimal Flight SQL test server: every statement query yields a FlightInfo with no
+// endpoints, and every DoGet yields an empty stream with an empty schema.
+class NoOpTestServer : public FlightSqlServerBase {
+ public:
+  // Returns an empty stream (no batches, empty schema) regardless of the ticket.
+  arrow::Result<std::unique_ptr<FlightDataStream>> DoGetStatement(
+      const ServerCallContext& context, const StatementQueryTicket& command) override {
+    ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, arrow::schema({})));
+    return std::make_unique<RecordBatchStream>(reader);
+  }
+
+  // Returns a FlightInfo with an empty schema and no endpoints. No ticket is created
+  // because no endpoint would carry it.
+  arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoStatement(
+      const ServerCallContext& context, const StatementQuery& command,
+      const FlightDescriptor& descriptor) override {
+    ARROW_ASSIGN_OR_RAISE(auto info,
+                          FlightInfo::Make(Schema({}), descriptor, {},
+                                           /*total_records=*/-1, /*total_bytes=*/-1));
+    return std::make_unique<FlightInfo>(std::move(info));
+  }
+};
+
+class AdbcTlsTest : public AdbcServerTest {
+ protected:
+ arrow::Result