From fef4d9dade9a3cf7cb7a08cd6bf1a9ce3c940ce7 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 28 Sep 2022 10:32:44 -0400
Subject: [PATCH 1/4] ARROW-17867: [C++][FlightRPC] Expose bulk parameter
binding in Flight SQL
---
cpp/src/arrow/flight/sql/CMakeLists.txt | 8 -
cpp/src/arrow/flight/sql/client.cc | 80 +--
cpp/src/arrow/flight/sql/client.h | 11 +-
cpp/src/arrow/flight/sql/client_test.cc | 530 ------------------
.../arrow/flight/sql/example/sqlite_server.cc | 156 +++---
.../arrow/flight/sql/example/sqlite_server.h | 5 +-
.../flight/sql/example/sqlite_statement.cc | 91 ++-
.../flight/sql/example/sqlite_statement.h | 12 +-
.../example/sqlite_statement_batch_reader.cc | 142 +++--
.../example/sqlite_statement_batch_reader.h | 4 +
.../sqlite_tables_schema_batch_reader.cc | 18 +-
cpp/src/arrow/flight/sql/server.cc | 94 ++--
cpp/src/arrow/flight/sql/server.h | 22 +-
cpp/src/arrow/flight/sql/server_test.cc | 174 +++---
14 files changed, 473 insertions(+), 874 deletions(-)
delete mode 100644 cpp/src/arrow/flight/sql/client_test.cc
diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt
index d5134f3e8b3..628b02b9d28 100644
--- a/cpp/src/arrow/flight/sql/CMakeLists.txt
+++ b/cpp/src/arrow/flight/sql/CMakeLists.txt
@@ -98,14 +98,6 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES})
set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc)
- if(NOT MSVC AND NOT MINGW)
- # ARROW-16902: getting Protobuf generated code to have all the
- # proper dllexport/dllimport declarations is difficult, since
- # protoc does not insert them everywhere needed to satisfy both
- # MinGW and MSVC, and the Protobuf team recommends against it
- list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS client_test.cc)
- endif()
-
if(ARROW_COMPUTE
AND ARROW_PARQUET
AND ARROW_SUBSTRAIT)
diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc
index 521cf9e8cd6..c4502b188f7 100644
--- a/cpp/src/arrow/flight/sql/client.cc
+++ b/cpp/src/arrow/flight/sql/client.cc
@@ -16,6 +16,7 @@
// under the License.
// Platform-specific defines
+#include "arrow/flight/client.h"
#include "arrow/flight/platform.h"
#include "arrow/flight/sql/client.h"
@@ -208,11 +209,12 @@ arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o
std::unique_ptr reader;
ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader));
-
std::shared_ptr metadata;
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata));
ARROW_RETURN_NOT_OK(writer->Close());
+ if (!metadata) return Status::IOError("Server did not send a response");
+
flight_sql_pb::DoPutUpdateResult result;
if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) {
return Status::Invalid("Unable to parse DoPutUpdateResult");
@@ -535,6 +537,24 @@ arrow::Result> PreparedStatement::ParseRespon
parameter_schema);
}
+arrow::Result> BindParameters(FlightClient* client,
+ const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ RecordBatchReader* params) {
+ ARROW_ASSIGN_OR_RAISE(auto stream,
+ client->DoPut(options, descriptor, params->schema()));
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, params->Next());
+ if (!batch) break;
+ ARROW_RETURN_NOT_OK(stream.writer->WriteRecordBatch(*batch));
+ }
+ ARROW_RETURN_NOT_OK(stream.writer->DoneWriting());
+ std::shared_ptr metadata;
+ ARROW_RETURN_NOT_OK(stream.reader->ReadMetadata(&metadata));
+ ARROW_RETURN_NOT_OK(stream.writer->Close());
+ return metadata;
+}
+
arrow::Result> PreparedStatement::Execute(
const FlightCallOptions& options) {
if (is_closed_) {
@@ -545,21 +565,11 @@ arrow::Result> PreparedStatement::Execute(
command.set_prepared_statement_handle(handle_);
ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
GetFlightDescriptorForCommand(command));
-
- if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
- std::unique_ptr writer;
- std::unique_ptr reader;
- ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(),
- &writer, &reader));
-
- ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
- ARROW_RETURN_NOT_OK(writer->DoneWriting());
- // Wait for the server to ack the result
- std::shared_ptr buffer;
- ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer));
- ARROW_RETURN_NOT_OK(writer->Close());
+ if (parameter_binding_) {
+ ARROW_ASSIGN_OR_RAISE(auto metadata,
+ BindParameters(client_->impl_.get(), options, descriptor,
+ parameter_binding_.get()));
}
-
ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor));
return std::move(flight_info);
}
@@ -574,26 +584,21 @@ arrow::Result PreparedStatement::ExecuteUpdate(
command.set_prepared_statement_handle(handle_);
ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
GetFlightDescriptorForCommand(command));
- std::unique_ptr writer;
- std::unique_ptr reader;
-
- if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
- ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(),
- &writer, &reader));
- ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
+ std::shared_ptr metadata;
+ if (parameter_binding_) {
+ ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), options,
+ descriptor, parameter_binding_.get()));
} else {
const std::shared_ptr schema = arrow::schema({});
- ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, schema, &writer, &reader));
- const ArrayVector columns;
- const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns);
- ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch));
+ auto record_batch = arrow::RecordBatch::Make(schema, 0, ArrayVector{});
+ ARROW_ASSIGN_OR_RAISE(auto params,
+ RecordBatchReader::Make({std::move(record_batch)}));
+ ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), options,
+ descriptor, params.get()));
+ }
+ if (!metadata) {
+ return Status::IOError("Server did not send a response to ", command.GetTypeName());
}
-
- ARROW_RETURN_NOT_OK(writer->DoneWriting());
- std::shared_ptr metadata;
- ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata));
- ARROW_RETURN_NOT_OK(writer->Close());
-
flight_sql_pb::DoPutUpdateResult result;
if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) {
return Status::Invalid("Unable to parse DoPutUpdateResult object.");
@@ -603,6 +608,13 @@ arrow::Result PreparedStatement::ExecuteUpdate(
}
Status PreparedStatement::SetParameters(std::shared_ptr parameter_binding) {
+ ARROW_ASSIGN_OR_RAISE(parameter_binding_,
+ RecordBatchReader::Make({std::move(parameter_binding)}));
+ return Status::OK();
+}
+
+Status PreparedStatement::SetParameters(
+ std::shared_ptr parameter_binding) {
parameter_binding_ = std::move(parameter_binding);
return Status::OK();
@@ -610,11 +622,11 @@ Status PreparedStatement::SetParameters(std::shared_ptr parameter_b
bool PreparedStatement::IsClosed() const { return is_closed_; }
-std::shared_ptr PreparedStatement::dataset_schema() const {
+const std::shared_ptr& PreparedStatement::dataset_schema() const {
return dataset_schema_;
}
-std::shared_ptr PreparedStatement::parameter_schema() const {
+const std::shared_ptr& PreparedStatement::parameter_schema() const {
return parameter_schema_;
}
diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h
index db168847ed6..548c1f33a63 100644
--- a/cpp/src/arrow/flight/sql/client.h
+++ b/cpp/src/arrow/flight/sql/client.h
@@ -392,17 +392,18 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
/// \brief Retrieve the parameter schema from the query.
/// \return The parameter schema from the query.
- std::shared_ptr parameter_schema() const;
+ [[nodiscard]] const std::shared_ptr& parameter_schema() const;
/// \brief Retrieve the ResultSet schema from the query.
/// \return The ResultSet schema from the query.
- std::shared_ptr dataset_schema() const;
+ [[nodiscard]] const std::shared_ptr& dataset_schema() const;
/// \brief Set a RecordBatch that contains the parameters that will be bound.
- /// \param parameter_binding The parameters that will be bound.
- /// \return Status.
Status SetParameters(std::shared_ptr parameter_binding);
+ /// \brief Set a RecordBatchReader that contains the parameters that will be bound.
+ Status SetParameters(std::shared_ptr parameter_binding);
+
/// \brief Re-request the result set schema from the server (should
/// be identical to dataset_schema).
arrow::Result> GetSchema(
@@ -422,7 +423,7 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
std::string handle_;
std::shared_ptr dataset_schema_;
std::shared_ptr parameter_schema_;
- std::shared_ptr parameter_binding_;
+ std::shared_ptr parameter_binding_;
bool is_closed_;
};
diff --git a/cpp/src/arrow/flight/sql/client_test.cc b/cpp/src/arrow/flight/sql/client_test.cc
deleted file mode 100644
index 984bf454816..00000000000
--- a/cpp/src/arrow/flight/sql/client_test.cc
+++ /dev/null
@@ -1,530 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-// Platform-specific defines
-#include "arrow/flight/platform.h"
-
-#include "arrow/flight/client.h"
-
-#include
-#include
-#include
-
-#include
-
-#include "arrow/buffer.h"
-#include "arrow/flight/sql/api.h"
-#include "arrow/flight/sql/protocol_internal.h"
-#include "arrow/testing/gtest_util.h"
-
-namespace pb = arrow::flight::protocol;
-using ::testing::_;
-using ::testing::Ref;
-
-namespace arrow {
-namespace flight {
-namespace sql {
-
-class FlightSqlClientMock : public FlightSqlClient {
- public:
- FlightSqlClientMock() : FlightSqlClient(nullptr) {}
-
- ~FlightSqlClientMock() = default;
-
- MOCK_METHOD(arrow::Result>, GetFlightInfo,
- (const FlightCallOptions&, const FlightDescriptor&));
- MOCK_METHOD(Status, DoGet,
- (const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr* stream));
- MOCK_METHOD(Status, DoPut,
- (const FlightCallOptions&, const FlightDescriptor&,
- const std::shared_ptr& schema,
- std::unique_ptr*,
- std::unique_ptr*));
- MOCK_METHOD(Status, DoAction,
- (const FlightCallOptions& options, const Action& action,
- std::unique_ptr* results));
-};
-
-class TestFlightSqlClient : public ::testing::Test {
- protected:
- FlightSqlClientMock sql_client_;
- FlightCallOptions call_options_;
-
- void SetUp() override {}
-
- void TearDown() override {}
-};
-
-class FlightMetadataReaderMock : public FlightMetadataReader {
- public:
- std::shared_ptr* buffer;
-
- explicit FlightMetadataReaderMock(std::shared_ptr* buffer) {
- this->buffer = buffer;
- }
-
- Status ReadMetadata(std::shared_ptr* out) override {
- *out = *buffer;
- return Status::OK();
- }
-};
-
-class FlightStreamWriterMock : public FlightStreamWriter {
- public:
- FlightStreamWriterMock() = default;
-
- Status DoneWriting() override { return Status::OK(); }
-
- Status WriteMetadata(std::shared_ptr app_metadata) override {
- return Status::OK();
- }
-
- Status Begin(const std::shared_ptr& schema,
- const ipc::IpcWriteOptions& options) override {
- return Status::OK();
- }
-
- Status Begin(const std::shared_ptr& schema) override {
- return MetadataRecordBatchWriter::Begin(schema);
- }
-
- ipc::WriteStats stats() const override { return ipc::WriteStats(); }
-
- Status WriteWithMetadata(const RecordBatch& batch,
- std::shared_ptr app_metadata) override {
- return Status::OK();
- }
-
- Status Close() override { return Status::OK(); }
-
- Status WriteRecordBatch(const RecordBatch& batch) override { return Status::OK(); }
-};
-
-FlightDescriptor getDescriptor(google::protobuf::Message& command) {
- google::protobuf::Any any;
- any.PackFrom(command);
-
- const std::string& string = any.SerializeAsString();
- return FlightDescriptor::Command(string);
-}
-
-auto ReturnEmptyFlightInfo = [](const FlightCallOptions& options,
- const FlightDescriptor& descriptor) {
- std::unique_ptr flight_info;
- return flight_info;
-};
-
-TEST_F(TestFlightSqlClient, TestGetCatalogs) {
- pb::sql::CommandGetCatalogs command;
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- ASSERT_OK(sql_client_.GetCatalogs(call_options_));
-}
-
-TEST_F(TestFlightSqlClient, TestGetDbSchemas) {
- std::string schema_filter_pattern = "schema_filter_pattern";
- std::string catalog = "catalog";
-
- pb::sql::CommandGetDbSchemas command;
- command.set_catalog(catalog);
- command.set_db_schema_filter_pattern(schema_filter_pattern);
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- ASSERT_OK(sql_client_.GetDbSchemas(call_options_, &catalog, &schema_filter_pattern));
-}
-
-TEST_F(TestFlightSqlClient, TestGetTables) {
- std::string catalog = "catalog";
- std::string schema_filter_pattern = "schema_filter_pattern";
- std::string table_name_filter_pattern = "table_name_filter_pattern";
- bool include_schema = true;
- std::vector table_types = {"type1", "type2"};
-
- pb::sql::CommandGetTables command;
- command.set_catalog(catalog);
- command.set_db_schema_filter_pattern(schema_filter_pattern);
- command.set_table_name_filter_pattern(table_name_filter_pattern);
- command.set_include_schema(include_schema);
- for (const std::string& table_type : table_types) {
- command.add_table_types(table_type);
- }
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- ASSERT_OK(sql_client_.GetTables(call_options_, &catalog, &schema_filter_pattern,
- &table_name_filter_pattern, include_schema,
- &table_types));
-}
-
-TEST_F(TestFlightSqlClient, TestGetTableTypes) {
- pb::sql::CommandGetTableTypes command;
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- ASSERT_OK(sql_client_.GetTableTypes(call_options_));
-}
-
-TEST_F(TestFlightSqlClient, TestGetTypeInfo) {
- pb::sql::CommandGetXdbcTypeInfo command;
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- ASSERT_OK(sql_client_.GetXdbcTypeInfo(call_options_));
-}
-
-TEST_F(TestFlightSqlClient, TestGetExported) {
- std::string catalog = "catalog";
- std::string schema = "schema";
- std::string table = "table";
-
- pb::sql::CommandGetExportedKeys command;
- command.set_catalog(catalog);
- command.set_db_schema(schema);
- command.set_table(table);
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table};
- ASSERT_OK(sql_client_.GetExportedKeys(call_options_, table_ref));
-}
-
-TEST_F(TestFlightSqlClient, TestGetImported) {
- std::string catalog = "catalog";
- std::string schema = "schema";
- std::string table = "table";
-
- pb::sql::CommandGetImportedKeys command;
- command.set_catalog(catalog);
- command.set_db_schema(schema);
- command.set_table(table);
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table};
- ASSERT_OK(sql_client_.GetImportedKeys(call_options_, table_ref));
-}
-
-TEST_F(TestFlightSqlClient, TestGetPrimary) {
- std::string catalog = "catalog";
- std::string schema = "schema";
- std::string table = "table";
-
- pb::sql::CommandGetPrimaryKeys command;
- command.set_catalog(catalog);
- command.set_db_schema(schema);
- command.set_table(table);
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table};
- ASSERT_OK(sql_client_.GetPrimaryKeys(call_options_, table_ref));
-}
-
-TEST_F(TestFlightSqlClient, TestGetCrossReference) {
- std::string pk_catalog = "pk_catalog";
- std::string pk_schema = "pk_schema";
- std::string pk_table = "pk_table";
- std::string fk_catalog = "fk_catalog";
- std::string fk_schema = "fk_schema";
- std::string fk_table = "fk_table";
-
- pb::sql::CommandGetCrossReference command;
- command.set_pk_catalog(pk_catalog);
- command.set_pk_db_schema(pk_schema);
- command.set_pk_table(pk_table);
- command.set_fk_catalog(fk_catalog);
- command.set_fk_db_schema(fk_schema);
- command.set_fk_table(fk_table);
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- TableRef pk_table_ref = {std::make_optional(pk_catalog), std::make_optional(pk_schema),
- pk_table};
- TableRef fk_table_ref = {std::make_optional(fk_catalog), std::make_optional(fk_schema),
- fk_table};
- ASSERT_OK(sql_client_.GetCrossReference(call_options_, pk_table_ref, fk_table_ref));
-}
-
-TEST_F(TestFlightSqlClient, TestExecute) {
- std::string query = "query";
-
- pb::sql::CommandStatementQuery command;
- command.set_query(query);
- FlightDescriptor descriptor = getDescriptor(command);
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- ASSERT_OK(sql_client_.Execute(call_options_, query));
-}
-
-TEST_F(TestFlightSqlClient, TestPreparedStatementExecute) {
- const std::string query = "query";
-
- ON_CALL(sql_client_, DoAction)
- .WillByDefault([](const FlightCallOptions& options, const Action& action,
- std::unique_ptr* results) {
- google::protobuf::Any command;
-
- pb::sql::ActionCreatePreparedStatementResult prepared_statement_result;
-
- prepared_statement_result.set_prepared_statement_handle("query");
-
- command.PackFrom(prepared_statement_result);
-
- *results = std::unique_ptr(new SimpleResultStream(
- {Result{Buffer::FromString(command.SerializeAsString())}}));
-
- return Status::OK();
- });
-
- EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2);
-
- ASSERT_OK_AND_ASSIGN(auto prepared_statement,
- sql_client_.Prepare(call_options_, query));
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(_, _));
-
- ASSERT_OK(prepared_statement->Execute());
-}
-
-TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteParameterBinding) {
- const std::string query = "query";
-
- ON_CALL(sql_client_, DoAction)
- .WillByDefault([](const FlightCallOptions& options, const Action& action,
- std::unique_ptr* results) {
- google::protobuf::Any command;
-
- pb::sql::ActionCreatePreparedStatementResult prepared_statement_result;
-
- prepared_statement_result.set_prepared_statement_handle("query");
-
- auto schema = arrow::schema({arrow::field("id", int64())});
-
- std::shared_ptr schema_buffer;
- const arrow::Result>& result =
- arrow::ipc::SerializeSchema(*schema);
-
- ARROW_ASSIGN_OR_RAISE(schema_buffer, result);
-
- prepared_statement_result.set_parameter_schema(schema_buffer->ToString());
-
- command.PackFrom(prepared_statement_result);
-
- *results = std::unique_ptr(new SimpleResultStream(
- {Result{Buffer::FromString(command.SerializeAsString())}}));
-
- return Status::OK();
- });
-
- std::shared_ptr buffer_ptr;
- ON_CALL(sql_client_, DoPut)
- .WillByDefault([&buffer_ptr](const FlightCallOptions& options,
- const FlightDescriptor& descriptor1,
- const std::shared_ptr& schema,
- std::unique_ptr* writer,
- std::unique_ptr* reader) {
- writer->reset(new FlightStreamWriterMock());
- reader->reset(new FlightMetadataReaderMock(&buffer_ptr));
-
- return Status::OK();
- });
-
- EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2);
- EXPECT_CALL(sql_client_, DoPut(_, _, _, _, _));
-
- ASSERT_OK_AND_ASSIGN(auto prepared_statement,
- sql_client_.Prepare(call_options_, query));
-
- auto parameter_schema = prepared_statement->parameter_schema();
-
- auto result = RecordBatchFromJSON(parameter_schema, "[[1]]");
- ASSERT_OK(prepared_statement->SetParameters(result));
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(_, _));
-
- ASSERT_OK(prepared_statement->Execute());
-}
-
-TEST_F(TestFlightSqlClient, TestExecuteUpdate) {
- std::string query = "query";
-
- pb::sql::CommandStatementUpdate command;
-
- command.set_query(query);
-
- google::protobuf::Any any;
- any.PackFrom(command);
-
- const FlightDescriptor& descriptor = FlightDescriptor::Command(any.SerializeAsString());
-
- pb::sql::DoPutUpdateResult doPutUpdateResult;
- doPutUpdateResult.set_record_count(100);
- const std::string& string = doPutUpdateResult.SerializeAsString();
-
- auto buffer_ptr = std::make_shared(
- reinterpret_cast(string.data()), doPutUpdateResult.ByteSizeLong());
-
- ON_CALL(sql_client_, DoPut)
- .WillByDefault([&buffer_ptr](const FlightCallOptions& options,
- const FlightDescriptor& descriptor1,
- const std::shared_ptr& schema,
- std::unique_ptr* writer,
- std::unique_ptr* reader) {
- reader->reset(new FlightMetadataReaderMock(&buffer_ptr));
- writer->reset(new FlightStreamWriterMock());
-
- return Status::OK();
- });
-
- std::unique_ptr flight_info;
- std::unique_ptr writer;
- std::unique_ptr reader;
- EXPECT_CALL(sql_client_, DoPut(Ref(call_options_), descriptor, _, _, _));
-
- ASSERT_OK_AND_ASSIGN(auto num_rows, sql_client_.ExecuteUpdate(call_options_, query));
-
- ASSERT_EQ(num_rows, 100);
-}
-
-TEST_F(TestFlightSqlClient, TestGetSqlInfo) {
- std::vector sql_info{pb::sql::SqlInfo::FLIGHT_SQL_SERVER_NAME,
- pb::sql::SqlInfo::FLIGHT_SQL_SERVER_VERSION,
- pb::sql::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION};
- pb::sql::CommandGetSqlInfo command;
-
- for (const auto& info : sql_info) command.add_info(info);
- google::protobuf::Any any;
- any.PackFrom(command);
- const FlightDescriptor& descriptor = FlightDescriptor::Command(any.SerializeAsString());
-
- ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo);
- EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor));
-
- ASSERT_OK(sql_client_.GetSqlInfo(call_options_, sql_info));
-}
-
-template
-inline void AssertTestPreparedStatementExecuteUpdateOk(
- Func func, const std::shared_ptr* schema, FlightSqlClientMock& sql_client_) {
- const std::string query = "SELECT * FROM IRRELEVANT";
- int64_t expected_rows = 100L;
- pb::sql::DoPutUpdateResult result;
- result.set_record_count(expected_rows);
-
- ON_CALL(sql_client_, DoAction)
- .WillByDefault([&query, &schema](const FlightCallOptions& options,
- const Action& action,
- std::unique_ptr* results) {
- google::protobuf::Any command;
- pb::sql::ActionCreatePreparedStatementResult prepared_statement_result;
-
- prepared_statement_result.set_prepared_statement_handle(query);
-
- if (schema != NULLPTR) {
- std::shared_ptr schema_buffer;
- const arrow::Result>& result =
- arrow::ipc::SerializeSchema(**schema);
-
- ARROW_ASSIGN_OR_RAISE(schema_buffer, result);
- prepared_statement_result.set_parameter_schema(schema_buffer->ToString());
- }
-
- command.PackFrom(prepared_statement_result);
- *results = std::unique_ptr(new SimpleResultStream(
- {Result{Buffer::FromString(command.SerializeAsString())}}));
-
- return Status::OK();
- });
- EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2);
-
- auto buffer = Buffer::FromString(result.SerializeAsString());
- ON_CALL(sql_client_, DoPut)
- .WillByDefault([&buffer](const FlightCallOptions& options,
- const FlightDescriptor& descriptor1,
- const std::shared_ptr& schema,
- std::unique_ptr* writer,
- std::unique_ptr* reader) {
- reader->reset(new FlightMetadataReaderMock(&buffer));
- writer->reset(new FlightStreamWriterMock());
- return Status::OK();
- });
- if (schema == NULLPTR) {
- EXPECT_CALL(sql_client_, DoPut(_, _, _, _, _));
- } else {
- EXPECT_CALL(sql_client_, DoPut(_, _, *schema, _, _));
- }
-
- ASSERT_OK_AND_ASSIGN(auto prepared_statement, sql_client_.Prepare({}, query));
- func(prepared_statement, sql_client_, schema, expected_rows);
- ASSERT_OK_AND_ASSIGN(auto rows, prepared_statement->ExecuteUpdate());
- ASSERT_EQ(expected_rows, rows);
- ASSERT_OK(prepared_statement->Close());
-}
-
-TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteUpdateNoParameterBinding) {
- AssertTestPreparedStatementExecuteUpdateOk(
- [](const std::shared_ptr& prepared_statement,
- FlightSqlClient& sql_client_, const std::shared_ptr* schema,
- const int64_t& row_count) {},
- NULLPTR, sql_client_);
-}
-
-TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteUpdateWithParameterBinding) {
- const auto schema = arrow::schema(
- {arrow::field("field0", arrow::utf8()), arrow::field("field1", arrow::uint8())});
- AssertTestPreparedStatementExecuteUpdateOk(
- [](const std::shared_ptr& prepared_statement,
- FlightSqlClient& sql_client_, const std::shared_ptr* schema,
- const int64_t& row_count) {
- auto string_array =
- ArrayFromJSON(utf8(), R"(["Lorem", "Ipsum", "Foo", "Bar", "Baz"])");
- auto uint8_array = ArrayFromJSON(uint8(), R"([0, 10, 15, 20, 25])");
- std::shared_ptr recordBatch =
- RecordBatch::Make(*schema, row_count, {string_array, uint8_array});
- ASSERT_OK(prepared_statement->SetParameters(recordBatch));
- },
- &schema, sql_client_);
-}
-
-} // namespace sql
-} // namespace flight
-} // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc
index 9997d970a6e..f9ec864caba 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc
+++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc
@@ -17,7 +17,6 @@
#include "arrow/flight/sql/example/sqlite_server.h"
-#include
#define BOOST_NO_CXX98_FUNCTION_BASE // ARROW-17805
#include
#include
@@ -26,24 +25,32 @@
#include
#include
+#include
+
+#include "arrow/array/builder_binary.h"
#include "arrow/flight/sql/example/sqlite_sql_info.h"
#include "arrow/flight/sql/example/sqlite_statement.h"
#include "arrow/flight/sql/example/sqlite_statement_batch_reader.h"
#include "arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h"
#include "arrow/flight/sql/example/sqlite_type_info.h"
#include "arrow/flight/sql/server.h"
+#include "arrow/scalar.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
namespace arrow {
namespace flight {
namespace sql {
namespace example {
+using arrow::internal::checked_cast;
+
namespace {
std::string PrepareQueryForGetTables(const GetTables& command) {
std::stringstream table_query;
- table_query << "SELECT null as catalog_name, null as schema_name, name as "
+ table_query << "SELECT 'main' as catalog_name, null as schema_name, name as "
"table_name, type as table_type FROM sqlite_master where 1=1";
if (command.catalog.has_value()) {
@@ -77,50 +84,25 @@ std::string PrepareQueryForGetTables(const GetTables& command) {
return table_query.str();
}
-Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* reader) {
+template
+Status SetParametersOnSQLiteStatement(SqliteStatement* statement,
+ FlightMessageReader* reader, Callback callback) {
+ sqlite3_stmt* stmt = statement->GetSqlite3Stmt();
while (true) {
ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next());
- std::shared_ptr& record_batch = chunk.data;
- if (record_batch == nullptr) break;
+ if (chunk.data == nullptr) break;
- const int64_t num_rows = record_batch->num_rows();
- const int& num_columns = record_batch->num_columns();
+ const int64_t num_rows = chunk.data->num_rows();
+ if (num_rows == 0) continue;
+ ARROW_RETURN_NOT_OK(statement->SetParameters({std::move(chunk.data)}));
for (int i = 0; i < num_rows; ++i) {
- for (int c = 0; c < num_columns; ++c) {
- const std::shared_ptr& column = record_batch->column(c);
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, column->GetScalar(i));
-
- auto& holder = static_cast(*scalar).value;
-
- switch (holder->type->id()) {
- case Type::INT64: {
- int64_t value = static_cast(*holder).value;
- sqlite3_bind_int64(stmt, c + 1, value);
- break;
- }
- case Type::FLOAT: {
- double value = static_cast(*holder).value;
- sqlite3_bind_double(stmt, c + 1, value);
- break;
- }
- case Type::STRING: {
- std::shared_ptr buffer = static_cast(*holder).value;
- sqlite3_bind_text(stmt, c + 1, reinterpret_cast(buffer->data()),
- static_cast(buffer->size()), SQLITE_TRANSIENT);
- break;
- }
- case Type::BINARY: {
- std::shared_ptr buffer = static_cast(*holder).value;
- sqlite3_bind_blob(stmt, c + 1, buffer->data(),
- static_cast(buffer->size()), SQLITE_TRANSIENT);
- break;
- }
- default:
- return Status::Invalid("Received unsupported data type: ",
- holder->type->ToString());
- }
+ if (sqlite3_clear_bindings(stmt) != SQLITE_OK) {
+ return Status::Invalid("Failed to reset bindings on row ", i, ": ",
+ sqlite3_errmsg(statement->db()));
}
+ ARROW_RETURN_NOT_OK(statement->Bind(/*batch_index=*/0, i));
+ ARROW_RETURN_NOT_OK(callback());
}
}
@@ -183,8 +165,8 @@ std::string PrepareQueryForGetImportedOrExportedKeys(const std::string& filter)
} // namespace
-std::shared_ptr GetArrowType(const char* sqlite_type) {
- if (sqlite_type == NULLPTR) {
+arrow::Result> GetArrowType(const char* sqlite_type) {
+ if (sqlite_type == nullptr || std::strlen(sqlite_type) == 0) {
// SQLite may not know the column type yet.
return null();
}
@@ -199,9 +181,8 @@ std::shared_ptr GetArrowType(const char* sqlite_type) {
boost::istarts_with(sqlite_type, "char") ||
boost::istarts_with(sqlite_type, "varchar")) {
return utf8();
- } else {
- throw std::invalid_argument("Invalid SQLite type: " + std::string(sqlite_type));
}
+ return Status::Invalid("Invalid SQLite type: ", sqlite_type);
}
int32_t GetSqlTypeFromTypeName(const char* sqlite_type) {
@@ -245,13 +226,17 @@ class SQLiteFlightSqlServer::Impl {
}
arrow::Result GetConnection(const std::string& transaction_id) {
- if (transaction_id.empty()) return db_;
+ if (transaction_id.empty()) {
+ ARROW_LOG(INFO) << "Using default connection";
+ return db_;
+ }
std::lock_guard guard(mutex_);
auto it = open_transactions_.find(transaction_id);
if (it == open_transactions_.end()) {
return Status::KeyError("Unknown transaction ID: ", transaction_id);
}
+ ARROW_LOG(INFO) << "Using connection for transaction " << transaction_id;
return it->second;
}
@@ -346,18 +331,17 @@ class SQLiteFlightSqlServer::Impl {
arrow::Result> DoGetCatalogs(
const ServerCallContext& context) {
- // As SQLite doesn't support catalogs, this will return an empty record batch.
-
+ // https://www.sqlite.org/cli.html
+ // > The ".databases" command shows a list of all databases open
+ // > in the current connection. There will always be at least
+ // > 2. The first one is "main", the original database opened.
const std::shared_ptr& schema = SqlSchema::GetCatalogsSchema();
-
StringBuilder catalog_name_builder;
+ ARROW_RETURN_NOT_OK(catalog_name_builder.Append("main"));
ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish());
-
- const std::shared_ptr& batch =
- RecordBatch::Make(schema, 0, {catalog_name});
-
+ std::shared_ptr batch =
+ RecordBatch::Make(schema, 1, {std::move(catalog_name)});
ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch}));
-
return std::make_unique(reader);
}
@@ -369,20 +353,26 @@ class SQLiteFlightSqlServer::Impl {
arrow::Result> DoGetDbSchemas(
const ServerCallContext& context, const GetDbSchemas& command) {
- // As SQLite doesn't support schemas, this will return an empty record batch.
-
+ // SQLite doesn't support schemas, so pretend we have a single
+ // unnamed schema.
const std::shared_ptr& schema = SqlSchema::GetDbSchemasSchema();
-
StringBuilder catalog_name_builder;
- ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish());
StringBuilder schema_name_builder;
- ARROW_ASSIGN_OR_RAISE(auto schema_name, schema_name_builder.Finish());
- const std::shared_ptr& batch =
- RecordBatch::Make(schema, 0, {catalog_name, schema_name});
+ int64_t length = 0;
+ // XXX: we don't really implement the full pattern match here
+ if ((!command.catalog || command.catalog == "main") &&
+ (!command.db_schema_filter_pattern || command.db_schema_filter_pattern == "%")) {
+ ARROW_RETURN_NOT_OK(catalog_name_builder.Append("main"));
+ ARROW_RETURN_NOT_OK(schema_name_builder.AppendNull());
+ length++;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish());
+ ARROW_ASSIGN_OR_RAISE(auto schema_name, schema_name_builder.Finish());
+ std::shared_ptr batch = RecordBatch::Make(
+ schema, length, {std::move(catalog_name), std::move(schema_name)});
ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch}));
-
return std::make_unique(reader);
}
@@ -392,6 +382,7 @@ class SQLiteFlightSqlServer::Impl {
std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}};
bool include_schema = command.include_schema;
+ ARROW_LOG(INFO) << "GetTables include_schema=" << include_schema;
ARROW_ASSIGN_OR_RAISE(
auto result,
@@ -399,12 +390,13 @@ class SQLiteFlightSqlServer::Impl {
: *SqlSchema::GetTablesSchema(),
descriptor, endpoints, -1, -1))
- return std::make_unique(result);
+ return std::make_unique(std::move(result));
}
arrow::Result> DoGetTables(
const ServerCallContext& context, const GetTables& command) {
std::string query = PrepareQueryForGetTables(command);
+ ARROW_LOG(INFO) << "GetTables: " << query;
std::shared_ptr statement;
ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, query));
@@ -426,6 +418,7 @@ class SQLiteFlightSqlServer::Impl {
const StatementUpdate& command) {
const std::string& sql = command.query;
ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id));
+ ARROW_LOG(INFO) << "Executing update: " << sql;
ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db, sql));
return statement->ExecuteUpdate();
}
@@ -433,11 +426,16 @@ class SQLiteFlightSqlServer::Impl {
arrow::Result CreatePreparedStatement(
const ServerCallContext& context,
const ActionCreatePreparedStatementRequest& request) {
- std::lock_guard guard(mutex_);
std::shared_ptr statement;
- ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, request.query));
+ ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(request.transaction_id));
+ ARROW_LOG(INFO) << "Creating prepared statement: " << request.query;
+ ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db, request.query));
std::string handle = GenerateRandomString();
- prepared_statements_[handle] = statement;
+
+ {
+ std::lock_guard guard(mutex_);
+ prepared_statements_[handle] = statement;
+ }
ARROW_ASSIGN_OR_RAISE(auto dataset_schema, statement->GetSchema());
@@ -525,10 +523,9 @@ class SQLiteFlightSqlServer::Impl {
const std::string& prepared_statement_handle = command.prepared_statement_handle;
ARROW_ASSIGN_OR_RAISE(auto statement,
GetStatementByHandle(prepared_statement_handle));
-
- sqlite3_stmt* stmt = statement->GetSqlite3Stmt();
- ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader));
-
+ // Save params here and execute later
+ ARROW_ASSIGN_OR_RAISE(auto batches, reader->ToRecordBatches());
+ ARROW_RETURN_NOT_OK(statement->SetParameters(std::move(batches)));
return Status::OK();
}
@@ -536,13 +533,20 @@ class SQLiteFlightSqlServer::Impl {
const ServerCallContext& context, const PreparedStatementUpdate& command,
FlightMessageReader* reader) {
const std::string& prepared_statement_handle = command.prepared_statement_handle;
- ARROW_ASSIGN_OR_RAISE(auto statement,
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr statement,
GetStatementByHandle(prepared_statement_handle));
- sqlite3_stmt* stmt = statement->GetSqlite3Stmt();
- ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader));
-
- return statement->ExecuteUpdate();
+ int64_t rows_affected = 0;
+ if (sqlite3_bind_parameter_count(statement->GetSqlite3Stmt()) == 0) {
+ ARROW_ASSIGN_OR_RAISE(rows_affected, statement->ExecuteUpdate());
+ } else {
+ ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(statement.get(), reader, [&]() {
+ ARROW_ASSIGN_OR_RAISE(int64_t rows, statement->ExecuteUpdate());
+ rows_affected += rows;
+ return statement->Reset().status();
+ }));
+ }
+ return rows_affected;
}
arrow::Result> GetFlightInfoTableTypes(
@@ -712,6 +716,8 @@ class SQLiteFlightSqlServer::Impl {
ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION"));
+ ARROW_LOG(INFO) << "Beginning transaction on " << handle;
+
std::lock_guard guard(mutex_);
open_transactions_[handle] = new_db;
return ActionBeginTransactionResult{std::move(handle)};
@@ -729,8 +735,10 @@ class SQLiteFlightSqlServer::Impl {
}
if (request.action == ActionEndTransactionRequest::kCommit) {
+ ARROW_LOG(INFO) << "Committing on " << request.transaction_id;
status = ExecuteSql(it->second, "COMMIT");
} else {
+ ARROW_LOG(INFO) << "Rolling back on " << request.transaction_id;
status = ExecuteSql(it->second, "ROLLBACK");
}
transaction = it->second;
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.h b/cpp/src/arrow/flight/sql/example/sqlite_server.h
index 389a2d921bb..d8c84e36e68 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_server.h
+++ b/cpp/src/arrow/flight/sql/example/sqlite_server.h
@@ -19,13 +19,14 @@
#include
+#include
#include
#include
-#include "arrow/api.h"
#include "arrow/flight/sql/example/sqlite_statement.h"
#include "arrow/flight/sql/example/sqlite_statement_batch_reader.h"
#include "arrow/flight/sql/server.h"
+#include "arrow/result.h"
namespace arrow {
namespace flight {
@@ -35,7 +36,7 @@ namespace example {
/// \brief Convert a column type to a ArrowType.
/// \param sqlite_type the sqlite type.
/// \return The equivalent ArrowType.
-std::shared_ptr GetArrowType(const char* sqlite_type);
+arrow::Result> GetArrowType(const char* sqlite_type);
/// \brief Convert a column type name to SQLite type.
/// \param type_name the type name.
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc
index d0810e2db02..85686ad8f0d 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_statement.cc
+++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc
@@ -17,19 +17,25 @@
#include "arrow/flight/sql/example/sqlite_statement.h"
-#include
+#include
-#define BOOST_NO_CXX98_FUNCTION_BASE // ARROW-17805
-#include
+#include
+#include "arrow/array/array_base.h"
#include "arrow/flight/sql/column_metadata.h"
#include "arrow/flight/sql/example/sqlite_server.h"
+#include "arrow/scalar.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
namespace arrow {
namespace flight {
namespace sql {
namespace example {
+using arrow::internal::checked_cast;
+
std::shared_ptr GetDataTypeFromSqliteType(const int column_type) {
switch (column_type) {
case SQLITE_INTEGER:
@@ -119,7 +125,7 @@ arrow::Result> SqliteStatement::GetSchema() const {
// Try to retrieve column type from sqlite3_column_decltype
const char* column_decltype = sqlite3_column_decltype(stmt_, i);
if (column_decltype != NULLPTR) {
- data_type = GetArrowType(column_decltype);
+ ARROW_ASSIGN_OR_RAISE(data_type, GetArrowType(column_decltype));
} else {
// If it can not determine the actual column type, return a dense_union type
// covering any type SQLite supports.
@@ -160,10 +166,85 @@ arrow::Result SqliteStatement::Reset() {
sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; }
arrow::Result SqliteStatement::ExecuteUpdate() {
- ARROW_RETURN_NOT_OK(Step());
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(int rc, Step());
+ if (rc == SQLITE_DONE) break;
+ }
return sqlite3_changes(db_);
}
+Status SqliteStatement::SetParameters(
+ std::vector> parameters) {
+ const int num_params = sqlite3_bind_parameter_count(stmt_);
+ for (const auto& batch : parameters) {
+ if (batch->num_columns() != num_params) {
+ return Status::Invalid("Expected ", num_params, " parameters, but got ",
+ batch->num_columns());
+ }
+ }
+ parameters_ = std::move(parameters);
+ auto end = std::remove_if(
+ parameters_.begin(), parameters_.end(),
+ [](const std::shared_ptr& batch) { return batch->num_rows() == 0; });
+ parameters_.erase(end, parameters_.end());
+ return Status::OK();
+}
+
+Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) {
+ if (batch_index >= parameters_.size()) {
+ return Status::Invalid("Cannot bind to batch ", batch_index);
+ }
+ const RecordBatch& batch = *parameters_[batch_index];
+ if (row_index < 0 || row_index >= batch.num_rows()) {
+ return Status::Invalid("Cannot bind to row ", row_index, " in batch ", batch_index);
+ }
+
+ if (sqlite3_clear_bindings(stmt_) != SQLITE_OK) {
+ return Status::Invalid("Failed to reset bindings: ", sqlite3_errmsg(db_));
+ }
+ for (int c = 0; c < batch.num_columns(); ++c) {
+ const std::shared_ptr& column = batch.column(c);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, column->GetScalar(row_index));
+ if (scalar->type->id() == Type::DENSE_UNION) {
+ scalar = checked_cast(*scalar).value;
+ }
+
+ int rc = 0;
+ if (!scalar->is_valid) {
+ rc = sqlite3_bind_null(stmt_, c + 1);
+ continue;
+ } else {
+ switch (scalar->type->id()) {
+ case Type::INT64: {
+ int64_t value = checked_cast(*scalar).value;
+ rc = sqlite3_bind_int64(stmt_, c + 1, value);
+ break;
+ }
+ case Type::FLOAT: {
+ float value = checked_cast(*scalar).value;
+ rc = sqlite3_bind_double(stmt_, c + 1, value);
+ break;
+ }
+ case Type::STRING: {
+ const std::shared_ptr& buffer =
+ checked_cast(*scalar).value;
+ rc = sqlite3_bind_text(stmt_, c + 1,
+ reinterpret_cast(buffer->data()),
+ static_cast(buffer->size()), SQLITE_TRANSIENT);
+ break;
+ }
+ default:
+ return Status::Invalid("Received unsupported data type: ", *scalar->type);
+ }
+ }
+ if (rc != SQLITE_OK) {
+ return Status::UnknownError("Failed to bind parameter: ", sqlite3_errmsg(db_));
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace example
} // namespace sql
} // namespace flight
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.h b/cpp/src/arrow/flight/sql/example/sqlite_statement.h
index b31eab506fa..1c1c813718b 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_statement.h
+++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.h
@@ -62,15 +62,25 @@ class SqliteStatement {
/// \brief Returns the underlying sqlite3_stmt.
/// \return A sqlite statement.
- sqlite3_stmt* GetSqlite3Stmt() const;
+ [[nodiscard]] sqlite3_stmt* GetSqlite3Stmt() const;
+
+ [[nodiscard]] sqlite3* db() const { return db_; }
/// \brief Executes an UPDATE, INSERT or DELETE statement.
/// \return The number of rows changed by execution.
arrow::Result ExecuteUpdate();
+ [[nodiscard]] const std::vector>& parameters()
+ const {
+ return parameters_;
+ }
+ Status SetParameters(std::vector> parameters);
+ Status Bind(size_t batch_index, int64_t row_index);
+
private:
sqlite3* db_;
sqlite3_stmt* stmt_;
+ std::vector> parameters_;
SqliteStatement(sqlite3* db, sqlite3_stmt* stmt) : db_(db), stmt_(stmt) {}
};
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc
index c247eb62875..27c72614c5d 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc
+++ b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc
@@ -54,10 +54,6 @@
case TYPE_CLASS##Type::type_id: { \
using c_type = typename TYPE_CLASS##Type::c_type; \
auto builder = reinterpret_cast(array_builder); \
- if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \
- ARROW_RETURN_NOT_OK(builder->AppendNull()); \
- break; \
- } \
const sqlite3_int64 value = sqlite3_column_int64(STMT, COLUMN); \
ARROW_RETURN_NOT_OK(builder->Append(static_cast(value))); \
break; \
@@ -66,10 +62,6 @@
#define FLOAT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \
case TYPE_CLASS##Type::type_id: { \
auto builder = reinterpret_cast(array_builder); \
- if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \
- ARROW_RETURN_NOT_OK(builder->AppendNull()); \
- break; \
- } \
const double value = sqlite3_column_double(STMT, COLUMN); \
ARROW_RETURN_NOT_OK( \
builder->Append(static_cast(value))); \
@@ -82,7 +74,7 @@ namespace sql {
namespace example {
// Batch size for SQLite statement results
-static constexpr int kMaxBatchSize = 1024;
+static constexpr int32_t kMaxBatchSize = 16384;
std::shared_ptr SqliteStatementBatchReader::schema() const { return schema_; }
@@ -95,8 +87,12 @@ SqliteStatementBatchReader::SqliteStatementBatchReader(
arrow::Result>
SqliteStatementBatchReader::Create(const std::shared_ptr& statement_) {
+ ARROW_RETURN_NOT_OK(statement_->Reset());
+ if (!statement_->parameters().empty()) {
+ // If there are parameters, infer the schema after binding the first row
+ ARROW_RETURN_NOT_OK(statement_->Bind(0, 0));
+ }
ARROW_RETURN_NOT_OK(statement_->Step());
-
ARROW_ASSIGN_OR_RAISE(auto schema, statement_->GetSchema());
std::shared_ptr result(
@@ -108,10 +104,8 @@ SqliteStatementBatchReader::Create(const std::shared_ptr& state
arrow::Result>
SqliteStatementBatchReader::Create(const std::shared_ptr& statement,
const std::shared_ptr& schema) {
- std::shared_ptr result(
+ return std::shared_ptr(
new SqliteStatementBatchReader(statement, schema));
-
- return result;
}
Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) {
@@ -127,61 +121,89 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) {
ARROW_RETURN_NOT_OK(MakeBuilder(default_memory_pool(), field_type, &builders[i]));
}
- if (!already_executed_) {
- ARROW_ASSIGN_OR_RAISE(rc_, statement_->Reset());
- ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step());
- already_executed_ = true;
- }
-
int64_t rows = 0;
- while (rows < kMaxBatchSize && rc_ == SQLITE_ROW) {
- rows++;
- for (int i = 0; i < num_fields; i++) {
- const std::shared_ptr& field = schema_->field(i);
- const std::shared_ptr& field_type = field->type();
- ArrayBuilder* array_builder = builders[i].get();
-
- // NOTE: This is not the optimal way of building Arrow vectors.
- // That would be to presize the builders to avoiding several resizing operations
- // when appending values and also to build one vector at a time.
- switch (field_type->id()) {
- // XXX This doesn't handle overflows when converting to the target
- // integer type.
- INT_BUILDER_CASE(Int64, stmt_, i)
- INT_BUILDER_CASE(UInt64, stmt_, i)
- INT_BUILDER_CASE(Int32, stmt_, i)
- INT_BUILDER_CASE(UInt32, stmt_, i)
- INT_BUILDER_CASE(Int16, stmt_, i)
- INT_BUILDER_CASE(UInt16, stmt_, i)
- INT_BUILDER_CASE(Int8, stmt_, i)
- INT_BUILDER_CASE(UInt8, stmt_, i)
- FLOAT_BUILDER_CASE(Double, stmt_, i)
- FLOAT_BUILDER_CASE(Float, stmt_, i)
- FLOAT_BUILDER_CASE(HalfFloat, stmt_, i)
- BINARY_BUILDER_CASE(Binary, stmt_, i)
- BINARY_BUILDER_CASE(LargeBinary, stmt_, i)
- STRING_BUILDER_CASE(String, stmt_, i)
- STRING_BUILDER_CASE(LargeString, stmt_, i)
- default:
- return Status::NotImplemented("Not implemented SQLite data conversion to ",
- field_type->name());
+ while (true) {
+ if (!already_executed_) {
+ ARROW_ASSIGN_OR_RAISE(rc_, statement_->Reset());
+ if (!statement_->parameters().empty()) {
+ if (batch_index_ >= statement_->parameters().size()) {
+ *out = nullptr;
+ break;
+ }
+ ARROW_RETURN_NOT_OK(statement_->Bind(batch_index_, row_index_));
}
+ ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step());
+ already_executed_ = true;
}
- ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step());
- }
+ while (rows < kMaxBatchSize && rc_ == SQLITE_ROW) {
+ rows++;
+ for (int i = 0; i < num_fields; i++) {
+ const std::shared_ptr& field = schema_->field(i);
+ const std::shared_ptr& field_type = field->type();
+ ArrayBuilder* array_builder = builders[i].get();
+
+ if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) {
+ ARROW_RETURN_NOT_OK(array_builder->AppendNull());
+ continue;
+ }
+
+ switch (field_type->id()) {
+ // XXX This doesn't handle overflows when converting to the target
+ // integer type.
+ INT_BUILDER_CASE(Int64, stmt_, i)
+ INT_BUILDER_CASE(UInt64, stmt_, i)
+ INT_BUILDER_CASE(Int32, stmt_, i)
+ INT_BUILDER_CASE(UInt32, stmt_, i)
+ INT_BUILDER_CASE(Int16, stmt_, i)
+ INT_BUILDER_CASE(UInt16, stmt_, i)
+ INT_BUILDER_CASE(Int8, stmt_, i)
+ INT_BUILDER_CASE(UInt8, stmt_, i)
+ FLOAT_BUILDER_CASE(Double, stmt_, i)
+ FLOAT_BUILDER_CASE(Float, stmt_, i)
+ FLOAT_BUILDER_CASE(HalfFloat, stmt_, i)
+ BINARY_BUILDER_CASE(Binary, stmt_, i)
+ BINARY_BUILDER_CASE(LargeBinary, stmt_, i)
+ STRING_BUILDER_CASE(String, stmt_, i)
+ STRING_BUILDER_CASE(LargeString, stmt_, i)
+ default:
+ return Status::NotImplemented("Not implemented SQLite data conversion to ",
+ field_type->name());
+ }
+ }
- if (rows > 0) {
- std::vector> arrays(builders.size());
- for (int i = 0; i < num_fields; i++) {
- ARROW_RETURN_NOT_OK(builders[i]->Finish(&arrays[i]));
+ ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step());
}
- *out = RecordBatch::Make(schema_, rows, arrays);
- } else {
- *out = NULLPTR;
- }
+ // If we still have bind parameters, bind again and retry
+ const std::vector>& params = statement_->parameters();
+ if (!params.empty() && rc_ == SQLITE_DONE && batch_index_ < params.size()) {
+ row_index_++;
+ if (row_index_ < params[batch_index_]->num_rows()) {
+ already_executed_ = false;
+ } else {
+ batch_index_++;
+ row_index_ = 0;
+ if (batch_index_ < params.size()) {
+ already_executed_ = false;
+ }
+ }
+
+ if (!already_executed_ && rows < kMaxBatchSize) continue;
+ }
+ if (rows > 0) {
+ std::vector> arrays(builders.size());
+ for (int i = 0; i < num_fields; i++) {
+ ARROW_RETURN_NOT_OK(builders[i]->Finish(&arrays[i]));
+ }
+
+ *out = RecordBatch::Make(schema_, rows, arrays);
+ } else {
+ *out = nullptr;
+ }
+ break;
+ }
return Status::OK();
}
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h
index 8a6bc6078e7..3fb9ae1f83c 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h
+++ b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h
@@ -55,6 +55,10 @@ class SqliteStatementBatchReader : public RecordBatchReader {
int rc_;
bool already_executed_;
+ // State for parameter binding
+ size_t batch_index_{0};
+ int64_t row_index_{0};
+
SqliteStatementBatchReader(std::shared_ptr statement,
std::shared_ptr schema);
};
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc
index 921dd13182e..55345ad477a 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc
+++ b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc
@@ -21,6 +21,7 @@
#include
+#include "arrow/array/builder_binary.h"
#include "arrow/flight/sql/column_metadata.h"
#include "arrow/flight/sql/example/sqlite_server.h"
#include "arrow/flight/sql/example/sqlite_statement.h"
@@ -80,16 +81,19 @@ Status SqliteTablesWithSchemaBatchReader::ReadNext(std::shared_ptr*
const ColumnMetadata& column_metadata = GetColumnMetadata(
GetSqlTypeFromTypeName(column_type), sqlite_table_name.c_str());
- column_fields.push_back(arrow::field(column_name, GetArrowType(column_type),
- nullable == 0,
+ std::shared_ptr arrow_type;
+ auto status = GetArrowType(column_type).Value(&arrow_type);
+ if (!status.ok()) {
+ return Status::NotImplemented("Unknown SQLite type '", column_type,
+ "' for column '", column_name, "' in table '",
+ table_name, "': ", status);
+ }
+ column_fields.push_back(arrow::field(column_name, arrow_type, nullable == 0,
column_metadata.metadata_map()));
}
}
- const arrow::Result>& value =
- ipc::SerializeSchema(*arrow::schema(column_fields));
-
- std::shared_ptr schema_buffer;
- ARROW_ASSIGN_OR_RAISE(schema_buffer, value);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema_buffer,
+ ipc::SerializeSchema(*arrow::schema(column_fields)));
column_fields.clear();
ARROW_RETURN_NOT_OK(schema_builder.Append(::std::string_view(*schema_buffer)));
diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc
index 80112c2a44d..7f6d9b75a88 100644
--- a/cpp/src/arrow/flight/sql/server.cc
+++ b/cpp/src/arrow/flight/sql/server.cc
@@ -1094,41 +1094,49 @@ arrow::Result FlightSqlServerBase::DoPutCommandSubstraitPlan(
return Status::NotImplemented("DoPutCommandSubstraitPlan not implemented");
}
-std::shared_ptr SqlSchema::GetCatalogsSchema() {
- return arrow::schema({field("catalog_name", utf8(), false)});
+const std::shared_ptr& SqlSchema::GetCatalogsSchema() {
+ static std::shared_ptr kSchema =
+ arrow::schema({field("catalog_name", utf8(), false)});
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetDbSchemasSchema() {
- return arrow::schema(
+const std::shared_ptr& SqlSchema::GetDbSchemasSchema() {
+ static std::shared_ptr kSchema = arrow::schema(
{field("catalog_name", utf8()), field("db_schema_name", utf8(), false)});
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetTablesSchema() {
- return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()),
- field("table_name", utf8(), false),
- field("table_type", utf8(), false)});
+const std::shared_ptr& SqlSchema::GetTablesSchema() {
+ static std::shared_ptr kSchema = arrow::schema(
+ {field("catalog_name", utf8()), field("db_schema_name", utf8()),
+ field("table_name", utf8(), false), field("table_type", utf8(), false)});
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetTablesSchemaWithIncludedSchema() {
- return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()),
- field("table_name", utf8(), false),
- field("table_type", utf8(), false),
- field("table_schema", binary(), false)});
+const std::shared_ptr& SqlSchema::GetTablesSchemaWithIncludedSchema() {
+ static std::shared_ptr kSchema = arrow::schema(
+ {field("catalog_name", utf8()), field("db_schema_name", utf8()),
+ field("table_name", utf8(), false), field("table_type", utf8(), false),
+ field("table_schema", binary(), false)});
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetTableTypesSchema() {
- return arrow::schema({field("table_type", utf8(), false)});
+const std::shared_ptr& SqlSchema::GetTableTypesSchema() {
+ static std::shared_ptr kSchema =
+ arrow::schema({field("table_type", utf8(), false)});
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetPrimaryKeysSchema() {
- return arrow::schema(
+const std::shared_ptr& SqlSchema::GetPrimaryKeysSchema() {
+ static std::shared_ptr kSchema = arrow::schema(
{field("catalog_name", utf8()), field("db_schema_name", utf8()),
field("table_name", utf8(), false), field("column_name", utf8(), false),
field("key_sequence", int32(), false), field("key_name", utf8())});
+ return kSchema;
}
-std::shared_ptr GetImportedExportedKeysAndCrossReferenceSchema() {
- return arrow::schema(
+const std::shared_ptr& GetImportedExportedKeysAndCrossReferenceSchema() {
+ static std::shared_ptr kSchema = arrow::schema(
{field("pk_catalog_name", utf8(), true), field("pk_db_schema_name", utf8(), true),
field("pk_table_name", utf8(), false), field("pk_column_name", utf8(), false),
field("fk_catalog_name", utf8(), true), field("fk_db_schema_name", utf8(), true),
@@ -1136,35 +1144,44 @@ std::shared_ptr GetImportedExportedKeysAndCrossReferenceSchema() {
field("key_sequence", int32(), false), field("fk_key_name", utf8(), true),
field("pk_key_name", utf8(), true), field("update_rule", uint8(), false),
field("delete_rule", uint8(), false)});
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetImportedKeysSchema() {
- return GetImportedExportedKeysAndCrossReferenceSchema();
+const std::shared_ptr& SqlSchema::GetImportedKeysSchema() {
+ static std::shared_ptr kSchema =
+ GetImportedExportedKeysAndCrossReferenceSchema();
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetExportedKeysSchema() {
- return GetImportedExportedKeysAndCrossReferenceSchema();
+const std::shared_ptr& SqlSchema::GetExportedKeysSchema() {
+ static std::shared_ptr kSchema =
+ GetImportedExportedKeysAndCrossReferenceSchema();
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetCrossReferenceSchema() {
- return GetImportedExportedKeysAndCrossReferenceSchema();
+const std::shared_ptr& SqlSchema::GetCrossReferenceSchema() {
+ static std::shared_ptr kSchema =
+ GetImportedExportedKeysAndCrossReferenceSchema();
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetSqlInfoSchema() {
- return arrow::schema({field("info_name", uint32(), false),
- field("value",
- dense_union({field("string_value", utf8(), false),
- field("bool_value", boolean(), false),
- field("bigint_value", int64(), false),
- field("int32_bitmask", int32(), false),
- field("string_list", list(utf8()), false),
- field("int32_to_int32_list_map",
- map(int32(), list(int32())), false)}),
- false)});
+const std::shared_ptr& SqlSchema::GetSqlInfoSchema() {
+ static std::shared_ptr kSchema =
+ arrow::schema({field("info_name", uint32(), false),
+ field("value",
+ dense_union({field("string_value", utf8(), false),
+ field("bool_value", boolean(), false),
+ field("bigint_value", int64(), false),
+ field("int32_bitmask", int32(), false),
+ field("string_list", list(utf8()), false),
+ field("int32_to_int32_list_map",
+ map(int32(), list(int32())), false)}),
+ false)});
+ return kSchema;
}
-std::shared_ptr SqlSchema::GetXdbcTypeInfoSchema() {
- return arrow::schema({
+const std::shared_ptr& SqlSchema::GetXdbcTypeInfoSchema() {
+ static std::shared_ptr kSchema = arrow::schema({
field("type_name", utf8(), false),
field("data_type", int32(), false),
field("column_size", int32()),
@@ -1185,6 +1202,7 @@ std::shared_ptr SqlSchema::GetXdbcTypeInfoSchema() {
field("num_prec_radix", int32()),
field("interval_precision", int32()),
});
+ return kSchema;
}
} // namespace sql
} // namespace flight
diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h
index 0fc8b714865..65f6670171d 100644
--- a/cpp/src/arrow/flight/sql/server.h
+++ b/cpp/src/arrow/flight/sql/server.h
@@ -686,50 +686,50 @@ class ARROW_FLIGHT_SQL_EXPORT SqlSchema {
public:
/// \brief Get the Schema used on GetCatalogs response.
/// \return The default schema template.
- static std::shared_ptr GetCatalogsSchema();
+ static const std::shared_ptr& GetCatalogsSchema();
/// \brief Get the Schema used on GetDbSchemas response.
/// \return The default schema template.
- static std::shared_ptr GetDbSchemasSchema();
+ static const std::shared_ptr& GetDbSchemasSchema();
/// \brief Get the Schema used on GetTables response when included schema
/// flags is set to false.
/// \return The default schema template.
- static std::shared_ptr GetTablesSchema();
+ static const std::shared_ptr& GetTablesSchema();
/// \brief Get the Schema used on GetTables response when included schema
/// flags is set to true.
/// \return The default schema template.
- static std::shared_ptr GetTablesSchemaWithIncludedSchema();
+ static const std::shared_ptr& GetTablesSchemaWithIncludedSchema();
/// \brief Get the Schema used on GetTableTypes response.
/// \return The default schema template.
- static std::shared_ptr GetTableTypesSchema();
+ static const std::shared_ptr& GetTableTypesSchema();
/// \brief Get the Schema used on GetPrimaryKeys response when included schema
/// flags is set to true.
/// \return The default schema template.
- static std::shared_ptr GetPrimaryKeysSchema();
+ static const std::shared_ptr& GetPrimaryKeysSchema();
/// \brief Get the Schema used on GetImportedKeys response.
/// \return The default schema template.
- static std::shared_ptr GetExportedKeysSchema();
+ static const std::shared_ptr& GetExportedKeysSchema();
/// \brief Get the Schema used on GetImportedKeys response.
/// \return The default schema template.
- static std::shared_ptr GetImportedKeysSchema();
+ static const std::shared_ptr& GetImportedKeysSchema();
/// \brief Get the Schema used on GetCrossReference response.
/// \return The default schema template.
- static std::shared_ptr