diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index d5134f3e8b3..628b02b9d28 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -98,14 +98,6 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES}) set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc) - if(NOT MSVC AND NOT MINGW) - # ARROW-16902: getting Protobuf generated code to have all the - # proper dllexport/dllimport declarations is difficult, since - # protoc does not insert them everywhere needed to satisfy both - # MinGW and MSVC, and the Protobuf team recommends against it - list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS client_test.cc) - endif() - if(ARROW_COMPUTE AND ARROW_PARQUET AND ARROW_SUBSTRAIT) diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index 521cf9e8cd6..25bf8e384ef 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -16,6 +16,7 @@ // under the License. // Platform-specific defines +#include "arrow/flight/client.h" #include "arrow/flight/platform.h" #include "arrow/flight/sql/client.h" @@ -208,11 +209,12 @@ arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o std::unique_ptr reader; ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader)); - std::shared_ptr metadata; ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata)); ARROW_RETURN_NOT_OK(writer->Close()); + if (!metadata) return Status::IOError("Server did not send a response"); + flight_sql_pb::DoPutUpdateResult result; if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) { return Status::Invalid("Unable to parse DoPutUpdateResult"); @@ -535,6 +537,24 @@ arrow::Result> PreparedStatement::ParseRespon parameter_schema); } +arrow::Result> BindParameters(FlightClient* client, + const FlightCallOptions& options, + const FlightDescriptor& descriptor, + RecordBatchReader* params) { + ARROW_ASSIGN_OR_RAISE(auto stream, + client->DoPut(options, descriptor, params->schema())); + while (true) { + ARROW_ASSIGN_OR_RAISE(auto batch, params->Next()); + if (!batch) break; + ARROW_RETURN_NOT_OK(stream.writer->WriteRecordBatch(*batch)); + } + ARROW_RETURN_NOT_OK(stream.writer->DoneWriting()); + std::shared_ptr metadata; + ARROW_RETURN_NOT_OK(stream.reader->ReadMetadata(&metadata)); + ARROW_RETURN_NOT_OK(stream.writer->Close()); + return metadata; +} + arrow::Result> PreparedStatement::Execute( const FlightCallOptions& options) { if (is_closed_) { @@ -545,21 +565,11 @@ arrow::Result> PreparedStatement::Execute( command.set_prepared_statement_handle(handle_); ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, GetFlightDescriptorForCommand(command)); - - if (parameter_binding_ && parameter_binding_->num_rows() > 0) { - std::unique_ptr writer; - std::unique_ptr reader; - ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(), - &writer, &reader)); - - ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_)); - ARROW_RETURN_NOT_OK(writer->DoneWriting()); - // Wait for the server to ack the result - std::shared_ptr buffer; - ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer)); - ARROW_RETURN_NOT_OK(writer->Close()); + if (parameter_binding_) { + ARROW_ASSIGN_OR_RAISE(auto metadata, + BindParameters(client_->impl_.get(), options, descriptor, + parameter_binding_.get())); } - ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor)); return std::move(flight_info); } @@ -574,26 +584,19 @@ arrow::Result PreparedStatement::ExecuteUpdate( command.set_prepared_statement_handle(handle_); ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, GetFlightDescriptorForCommand(command)); - std::unique_ptr writer; - std::unique_ptr reader; - - if (parameter_binding_ && parameter_binding_->num_rows() > 0) { - ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(), - &writer, &reader)); - ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_)); + std::shared_ptr metadata; + if (parameter_binding_) { + ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), options, + descriptor, parameter_binding_.get())); } else { const std::shared_ptr schema = arrow::schema({}); - ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, schema, &writer, &reader)); - const ArrayVector columns; - const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns); - ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch)); + ARROW_ASSIGN_OR_RAISE(auto params, RecordBatchReader::Make({}, schema)); + ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), options, + descriptor, params.get())); + } + if (!metadata) { + return Status::IOError("Server did not send a response to ", command.GetTypeName()); } - - ARROW_RETURN_NOT_OK(writer->DoneWriting()); - std::shared_ptr metadata; - ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata)); - ARROW_RETURN_NOT_OK(writer->Close()); - flight_sql_pb::DoPutUpdateResult result; if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) { return Status::Invalid("Unable to parse DoPutUpdateResult object."); @@ -603,6 +606,13 @@ arrow::Result PreparedStatement::ExecuteUpdate( } Status PreparedStatement::SetParameters(std::shared_ptr parameter_binding) { + ARROW_ASSIGN_OR_RAISE(parameter_binding_, + RecordBatchReader::Make({std::move(parameter_binding)})); + return Status::OK(); +} + +Status PreparedStatement::SetParameters( + std::shared_ptr parameter_binding) { parameter_binding_ = std::move(parameter_binding); return Status::OK(); @@ -610,11 +620,11 @@ Status PreparedStatement::SetParameters(std::shared_ptr parameter_b bool PreparedStatement::IsClosed() const { return is_closed_; } -std::shared_ptr PreparedStatement::dataset_schema() const { +const std::shared_ptr& PreparedStatement::dataset_schema() const { return dataset_schema_; } -std::shared_ptr PreparedStatement::parameter_schema() const { +const std::shared_ptr& PreparedStatement::parameter_schema() const { return parameter_schema_; } diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index db168847ed6..648f71563e9 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -392,17 +392,18 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// \brief Retrieve the parameter schema from the query. /// \return The parameter schema from the query. - std::shared_ptr parameter_schema() const; + const std::shared_ptr& parameter_schema() const; /// \brief Retrieve the ResultSet schema from the query. /// \return The ResultSet schema from the query. - std::shared_ptr dataset_schema() const; + const std::shared_ptr& dataset_schema() const; /// \brief Set a RecordBatch that contains the parameters that will be bound. - /// \param parameter_binding The parameters that will be bound. - /// \return Status. Status SetParameters(std::shared_ptr parameter_binding); + /// \brief Set a RecordBatchReader that contains the parameters that will be bound. + Status SetParameters(std::shared_ptr parameter_binding); + /// \brief Re-request the result set schema from the server (should /// be identical to dataset_schema). arrow::Result> GetSchema( @@ -422,7 +423,7 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { std::string handle_; std::shared_ptr dataset_schema_; std::shared_ptr parameter_schema_; - std::shared_ptr parameter_binding_; + std::shared_ptr parameter_binding_; bool is_closed_; }; diff --git a/cpp/src/arrow/flight/sql/client_test.cc b/cpp/src/arrow/flight/sql/client_test.cc deleted file mode 100644 index 984bf454816..00000000000 --- a/cpp/src/arrow/flight/sql/client_test.cc +++ /dev/null @@ -1,530 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Platform-specific defines -#include "arrow/flight/platform.h" - -#include "arrow/flight/client.h" - -#include -#include -#include - -#include - -#include "arrow/buffer.h" -#include "arrow/flight/sql/api.h" -#include "arrow/flight/sql/protocol_internal.h" -#include "arrow/testing/gtest_util.h" - -namespace pb = arrow::flight::protocol; -using ::testing::_; -using ::testing::Ref; - -namespace arrow { -namespace flight { -namespace sql { - -class FlightSqlClientMock : public FlightSqlClient { - public: - FlightSqlClientMock() : FlightSqlClient(nullptr) {} - - ~FlightSqlClientMock() = default; - - MOCK_METHOD(arrow::Result>, GetFlightInfo, - (const FlightCallOptions&, const FlightDescriptor&)); - MOCK_METHOD(Status, DoGet, - (const FlightCallOptions& options, const Ticket& ticket, - std::unique_ptr* stream)); - MOCK_METHOD(Status, DoPut, - (const FlightCallOptions&, const FlightDescriptor&, - const std::shared_ptr& schema, - std::unique_ptr*, - std::unique_ptr*)); - MOCK_METHOD(Status, DoAction, - (const FlightCallOptions& options, const Action& action, - std::unique_ptr* results)); -}; - -class TestFlightSqlClient : public ::testing::Test { - protected: - FlightSqlClientMock sql_client_; - FlightCallOptions call_options_; - - void SetUp() override {} - - void TearDown() override {} -}; - -class FlightMetadataReaderMock : public FlightMetadataReader { - public: - std::shared_ptr* buffer; - - explicit FlightMetadataReaderMock(std::shared_ptr* buffer) { - this->buffer = buffer; - } - - Status ReadMetadata(std::shared_ptr* out) override { - *out = *buffer; - return Status::OK(); - } -}; - -class FlightStreamWriterMock : public FlightStreamWriter { - public: - FlightStreamWriterMock() = default; - - Status DoneWriting() override { return Status::OK(); } - - Status WriteMetadata(std::shared_ptr app_metadata) override { - return Status::OK(); - } - - Status Begin(const std::shared_ptr& schema, - const ipc::IpcWriteOptions& options) override { - return Status::OK(); - } - - Status Begin(const std::shared_ptr& schema) override { - return MetadataRecordBatchWriter::Begin(schema); - } - - ipc::WriteStats stats() const override { return ipc::WriteStats(); } - - Status WriteWithMetadata(const RecordBatch& batch, - std::shared_ptr app_metadata) override { - return Status::OK(); - } - - Status Close() override { return Status::OK(); } - - Status WriteRecordBatch(const RecordBatch& batch) override { return Status::OK(); } -}; - -FlightDescriptor getDescriptor(google::protobuf::Message& command) { - google::protobuf::Any any; - any.PackFrom(command); - - const std::string& string = any.SerializeAsString(); - return FlightDescriptor::Command(string); -} - -auto ReturnEmptyFlightInfo = [](const FlightCallOptions& options, - const FlightDescriptor& descriptor) { - std::unique_ptr flight_info; - return flight_info; -}; - -TEST_F(TestFlightSqlClient, TestGetCatalogs) { - pb::sql::CommandGetCatalogs command; - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - ASSERT_OK(sql_client_.GetCatalogs(call_options_)); -} - -TEST_F(TestFlightSqlClient, TestGetDbSchemas) { - std::string schema_filter_pattern = "schema_filter_pattern"; - std::string catalog = "catalog"; - - pb::sql::CommandGetDbSchemas command; - command.set_catalog(catalog); - command.set_db_schema_filter_pattern(schema_filter_pattern); - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - ASSERT_OK(sql_client_.GetDbSchemas(call_options_, &catalog, &schema_filter_pattern)); -} - -TEST_F(TestFlightSqlClient, TestGetTables) { - std::string catalog = "catalog"; - std::string schema_filter_pattern = "schema_filter_pattern"; - std::string table_name_filter_pattern = "table_name_filter_pattern"; - bool include_schema = true; - std::vector table_types = {"type1", "type2"}; - - pb::sql::CommandGetTables command; - command.set_catalog(catalog); - command.set_db_schema_filter_pattern(schema_filter_pattern); - command.set_table_name_filter_pattern(table_name_filter_pattern); - command.set_include_schema(include_schema); - for (const std::string& table_type : table_types) { - command.add_table_types(table_type); - } - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - ASSERT_OK(sql_client_.GetTables(call_options_, &catalog, &schema_filter_pattern, - &table_name_filter_pattern, include_schema, - &table_types)); -} - -TEST_F(TestFlightSqlClient, TestGetTableTypes) { - pb::sql::CommandGetTableTypes command; - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - ASSERT_OK(sql_client_.GetTableTypes(call_options_)); -} - -TEST_F(TestFlightSqlClient, TestGetTypeInfo) { - pb::sql::CommandGetXdbcTypeInfo command; - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - ASSERT_OK(sql_client_.GetXdbcTypeInfo(call_options_)); -} - -TEST_F(TestFlightSqlClient, TestGetExported) { - std::string catalog = "catalog"; - std::string schema = "schema"; - std::string table = "table"; - - pb::sql::CommandGetExportedKeys command; - command.set_catalog(catalog); - command.set_db_schema(schema); - command.set_table(table); - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table}; - ASSERT_OK(sql_client_.GetExportedKeys(call_options_, table_ref)); -} - -TEST_F(TestFlightSqlClient, TestGetImported) { - std::string catalog = "catalog"; - std::string schema = "schema"; - std::string table = "table"; - - pb::sql::CommandGetImportedKeys command; - command.set_catalog(catalog); - command.set_db_schema(schema); - command.set_table(table); - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table}; - ASSERT_OK(sql_client_.GetImportedKeys(call_options_, table_ref)); -} - -TEST_F(TestFlightSqlClient, TestGetPrimary) { - std::string catalog = "catalog"; - std::string schema = "schema"; - std::string table = "table"; - - pb::sql::CommandGetPrimaryKeys command; - command.set_catalog(catalog); - command.set_db_schema(schema); - command.set_table(table); - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table}; - ASSERT_OK(sql_client_.GetPrimaryKeys(call_options_, table_ref)); -} - -TEST_F(TestFlightSqlClient, TestGetCrossReference) { - std::string pk_catalog = "pk_catalog"; - std::string pk_schema = "pk_schema"; - std::string pk_table = "pk_table"; - std::string fk_catalog = "fk_catalog"; - std::string fk_schema = "fk_schema"; - std::string fk_table = "fk_table"; - - pb::sql::CommandGetCrossReference command; - command.set_pk_catalog(pk_catalog); - command.set_pk_db_schema(pk_schema); - command.set_pk_table(pk_table); - command.set_fk_catalog(fk_catalog); - command.set_fk_db_schema(fk_schema); - command.set_fk_table(fk_table); - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - TableRef pk_table_ref = {std::make_optional(pk_catalog), std::make_optional(pk_schema), - pk_table}; - TableRef fk_table_ref = {std::make_optional(fk_catalog), std::make_optional(fk_schema), - fk_table}; - ASSERT_OK(sql_client_.GetCrossReference(call_options_, pk_table_ref, fk_table_ref)); -} - -TEST_F(TestFlightSqlClient, TestExecute) { - std::string query = "query"; - - pb::sql::CommandStatementQuery command; - command.set_query(query); - FlightDescriptor descriptor = getDescriptor(command); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - ASSERT_OK(sql_client_.Execute(call_options_, query)); -} - -TEST_F(TestFlightSqlClient, TestPreparedStatementExecute) { - const std::string query = "query"; - - ON_CALL(sql_client_, DoAction) - .WillByDefault([](const FlightCallOptions& options, const Action& action, - std::unique_ptr* results) { - google::protobuf::Any command; - - pb::sql::ActionCreatePreparedStatementResult prepared_statement_result; - - prepared_statement_result.set_prepared_statement_handle("query"); - - command.PackFrom(prepared_statement_result); - - *results = std::unique_ptr(new SimpleResultStream( - {Result{Buffer::FromString(command.SerializeAsString())}})); - - return Status::OK(); - }); - - EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2); - - ASSERT_OK_AND_ASSIGN(auto prepared_statement, - sql_client_.Prepare(call_options_, query)); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(_, _)); - - ASSERT_OK(prepared_statement->Execute()); -} - -TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteParameterBinding) { - const std::string query = "query"; - - ON_CALL(sql_client_, DoAction) - .WillByDefault([](const FlightCallOptions& options, const Action& action, - std::unique_ptr* results) { - google::protobuf::Any command; - - pb::sql::ActionCreatePreparedStatementResult prepared_statement_result; - - prepared_statement_result.set_prepared_statement_handle("query"); - - auto schema = arrow::schema({arrow::field("id", int64())}); - - std::shared_ptr schema_buffer; - const arrow::Result>& result = - arrow::ipc::SerializeSchema(*schema); - - ARROW_ASSIGN_OR_RAISE(schema_buffer, result); - - prepared_statement_result.set_parameter_schema(schema_buffer->ToString()); - - command.PackFrom(prepared_statement_result); - - *results = std::unique_ptr(new SimpleResultStream( - {Result{Buffer::FromString(command.SerializeAsString())}})); - - return Status::OK(); - }); - - std::shared_ptr buffer_ptr; - ON_CALL(sql_client_, DoPut) - .WillByDefault([&buffer_ptr](const FlightCallOptions& options, - const FlightDescriptor& descriptor1, - const std::shared_ptr& schema, - std::unique_ptr* writer, - std::unique_ptr* reader) { - writer->reset(new FlightStreamWriterMock()); - reader->reset(new FlightMetadataReaderMock(&buffer_ptr)); - - return Status::OK(); - }); - - EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2); - EXPECT_CALL(sql_client_, DoPut(_, _, _, _, _)); - - ASSERT_OK_AND_ASSIGN(auto prepared_statement, - sql_client_.Prepare(call_options_, query)); - - auto parameter_schema = prepared_statement->parameter_schema(); - - auto result = RecordBatchFromJSON(parameter_schema, "[[1]]"); - ASSERT_OK(prepared_statement->SetParameters(result)); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(_, _)); - - ASSERT_OK(prepared_statement->Execute()); -} - -TEST_F(TestFlightSqlClient, TestExecuteUpdate) { - std::string query = "query"; - - pb::sql::CommandStatementUpdate command; - - command.set_query(query); - - google::protobuf::Any any; - any.PackFrom(command); - - const FlightDescriptor& descriptor = FlightDescriptor::Command(any.SerializeAsString()); - - pb::sql::DoPutUpdateResult doPutUpdateResult; - doPutUpdateResult.set_record_count(100); - const std::string& string = doPutUpdateResult.SerializeAsString(); - - auto buffer_ptr = std::make_shared( - reinterpret_cast(string.data()), doPutUpdateResult.ByteSizeLong()); - - ON_CALL(sql_client_, DoPut) - .WillByDefault([&buffer_ptr](const FlightCallOptions& options, - const FlightDescriptor& descriptor1, - const std::shared_ptr& schema, - std::unique_ptr* writer, - std::unique_ptr* reader) { - reader->reset(new FlightMetadataReaderMock(&buffer_ptr)); - writer->reset(new FlightStreamWriterMock()); - - return Status::OK(); - }); - - std::unique_ptr flight_info; - std::unique_ptr writer; - std::unique_ptr reader; - EXPECT_CALL(sql_client_, DoPut(Ref(call_options_), descriptor, _, _, _)); - - ASSERT_OK_AND_ASSIGN(auto num_rows, sql_client_.ExecuteUpdate(call_options_, query)); - - ASSERT_EQ(num_rows, 100); -} - -TEST_F(TestFlightSqlClient, TestGetSqlInfo) { - std::vector sql_info{pb::sql::SqlInfo::FLIGHT_SQL_SERVER_NAME, - pb::sql::SqlInfo::FLIGHT_SQL_SERVER_VERSION, - pb::sql::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION}; - pb::sql::CommandGetSqlInfo command; - - for (const auto& info : sql_info) command.add_info(info); - google::protobuf::Any any; - any.PackFrom(command); - const FlightDescriptor& descriptor = FlightDescriptor::Command(any.SerializeAsString()); - - ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); - EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - - ASSERT_OK(sql_client_.GetSqlInfo(call_options_, sql_info)); -} - -template -inline void AssertTestPreparedStatementExecuteUpdateOk( - Func func, const std::shared_ptr* schema, FlightSqlClientMock& sql_client_) { - const std::string query = "SELECT * FROM IRRELEVANT"; - int64_t expected_rows = 100L; - pb::sql::DoPutUpdateResult result; - result.set_record_count(expected_rows); - - ON_CALL(sql_client_, DoAction) - .WillByDefault([&query, &schema](const FlightCallOptions& options, - const Action& action, - std::unique_ptr* results) { - google::protobuf::Any command; - pb::sql::ActionCreatePreparedStatementResult prepared_statement_result; - - prepared_statement_result.set_prepared_statement_handle(query); - - if (schema != NULLPTR) { - std::shared_ptr schema_buffer; - const arrow::Result>& result = - arrow::ipc::SerializeSchema(**schema); - - ARROW_ASSIGN_OR_RAISE(schema_buffer, result); - prepared_statement_result.set_parameter_schema(schema_buffer->ToString()); - } - - command.PackFrom(prepared_statement_result); - *results = std::unique_ptr(new SimpleResultStream( - {Result{Buffer::FromString(command.SerializeAsString())}})); - - return Status::OK(); - }); - EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2); - - auto buffer = Buffer::FromString(result.SerializeAsString()); - ON_CALL(sql_client_, DoPut) - .WillByDefault([&buffer](const FlightCallOptions& options, - const FlightDescriptor& descriptor1, - const std::shared_ptr& schema, - std::unique_ptr* writer, - std::unique_ptr* reader) { - reader->reset(new FlightMetadataReaderMock(&buffer)); - writer->reset(new FlightStreamWriterMock()); - return Status::OK(); - }); - if (schema == NULLPTR) { - EXPECT_CALL(sql_client_, DoPut(_, _, _, _, _)); - } else { - EXPECT_CALL(sql_client_, DoPut(_, _, *schema, _, _)); - } - - ASSERT_OK_AND_ASSIGN(auto prepared_statement, sql_client_.Prepare({}, query)); - func(prepared_statement, sql_client_, schema, expected_rows); - ASSERT_OK_AND_ASSIGN(auto rows, prepared_statement->ExecuteUpdate()); - ASSERT_EQ(expected_rows, rows); - ASSERT_OK(prepared_statement->Close()); -} - -TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteUpdateNoParameterBinding) { - AssertTestPreparedStatementExecuteUpdateOk( - [](const std::shared_ptr& prepared_statement, - FlightSqlClient& sql_client_, const std::shared_ptr* schema, - const int64_t& row_count) {}, - NULLPTR, sql_client_); -} - -TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteUpdateWithParameterBinding) { - const auto schema = arrow::schema( - {arrow::field("field0", arrow::utf8()), arrow::field("field1", arrow::uint8())}); - AssertTestPreparedStatementExecuteUpdateOk( - [](const std::shared_ptr& prepared_statement, - FlightSqlClient& sql_client_, const std::shared_ptr* schema, - const int64_t& row_count) { - auto string_array = - ArrayFromJSON(utf8(), R"(["Lorem", "Ipsum", "Foo", "Bar", "Baz"])"); - auto uint8_array = ArrayFromJSON(uint8(), R"([0, 10, 15, 20, 25])"); - std::shared_ptr recordBatch = - RecordBatch::Make(*schema, row_count, {string_array, uint8_array}); - ASSERT_OK(prepared_statement->SetParameters(recordBatch)); - }, - &schema, sql_client_); -} - -} // namespace sql -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc index 9997d970a6e..a02f825a9e6 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -17,7 +17,6 @@ #include "arrow/flight/sql/example/sqlite_server.h" -#include #define BOOST_NO_CXX98_FUNCTION_BASE // ARROW-17805 #include #include @@ -26,24 +25,32 @@ #include #include +#include + +#include "arrow/array/builder_binary.h" #include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/sql/example/sqlite_statement.h" #include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" #include "arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h" #include "arrow/flight/sql/example/sqlite_type_info.h" #include "arrow/flight/sql/server.h" +#include "arrow/scalar.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" namespace arrow { namespace flight { namespace sql { namespace example { +using arrow::internal::checked_cast; + namespace { std::string PrepareQueryForGetTables(const GetTables& command) { std::stringstream table_query; - table_query << "SELECT null as catalog_name, null as schema_name, name as " + table_query << "SELECT 'main' as catalog_name, null as schema_name, name as " "table_name, type as table_type FROM sqlite_master where 1=1"; if (command.catalog.has_value()) { @@ -77,50 +84,27 @@ std::string PrepareQueryForGetTables(const GetTables& command) { return table_query.str(); } -Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* reader) { +template +Status SetParametersOnSQLiteStatement(SqliteStatement* statement, + FlightMessageReader* reader, Callback callback) { + sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); while (true) { ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next()); - std::shared_ptr& record_batch = chunk.data; - if (record_batch == nullptr) break; + if (chunk.data == nullptr) break; - const int64_t num_rows = record_batch->num_rows(); - const int& num_columns = record_batch->num_columns(); + const int64_t num_rows = chunk.data->num_rows(); + if (num_rows == 0) continue; + ARROW_RETURN_NOT_OK(statement->SetParameters({std::move(chunk.data)})); for (int i = 0; i < num_rows; ++i) { - for (int c = 0; c < num_columns; ++c) { - const std::shared_ptr& column = record_batch->column(c); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, column->GetScalar(i)); - - auto& holder = static_cast(*scalar).value; - - switch (holder->type->id()) { - case Type::INT64: { - int64_t value = static_cast(*holder).value; - sqlite3_bind_int64(stmt, c + 1, value); - break; - } - case Type::FLOAT: { - double value = static_cast(*holder).value; - sqlite3_bind_double(stmt, c + 1, value); - break; - } - case Type::STRING: { - std::shared_ptr buffer = static_cast(*holder).value; - sqlite3_bind_text(stmt, c + 1, reinterpret_cast(buffer->data()), - static_cast(buffer->size()), SQLITE_TRANSIENT); - break; - } - case Type::BINARY: { - std::shared_ptr buffer = static_cast(*holder).value; - sqlite3_bind_blob(stmt, c + 1, buffer->data(), - static_cast(buffer->size()), SQLITE_TRANSIENT); - break; - } - default: - return Status::Invalid("Received unsupported data type: ", - holder->type->ToString()); - } + if (sqlite3_clear_bindings(stmt) != SQLITE_OK) { + return Status::Invalid("Failed to reset bindings on row ", i, ": ", + sqlite3_errmsg(statement->db())); } + // batch_index is always 0 since we're calling SetParameters + // with a single batch at a time + ARROW_RETURN_NOT_OK(statement->Bind(/*batch_index=*/0, i)); + ARROW_RETURN_NOT_OK(callback()); } } @@ -183,8 +167,8 @@ std::string PrepareQueryForGetImportedOrExportedKeys(const std::string& filter) } // namespace -std::shared_ptr GetArrowType(const char* sqlite_type) { - if (sqlite_type == NULLPTR) { +arrow::Result> GetArrowType(const char* sqlite_type) { + if (sqlite_type == nullptr || std::strlen(sqlite_type) == 0) { // SQLite may not know the column type yet. return null(); } @@ -199,9 +183,8 @@ std::shared_ptr GetArrowType(const char* sqlite_type) { boost::istarts_with(sqlite_type, "char") || boost::istarts_with(sqlite_type, "varchar")) { return utf8(); - } else { - throw std::invalid_argument("Invalid SQLite type: " + std::string(sqlite_type)); } + return Status::Invalid("Invalid SQLite type: ", sqlite_type); } int32_t GetSqlTypeFromTypeName(const char* sqlite_type) { @@ -245,13 +228,17 @@ class SQLiteFlightSqlServer::Impl { } arrow::Result GetConnection(const std::string& transaction_id) { - if (transaction_id.empty()) return db_; + if (transaction_id.empty()) { + ARROW_LOG(INFO) << "Using default connection"; + return db_; + } std::lock_guard guard(mutex_); auto it = open_transactions_.find(transaction_id); if (it == open_transactions_.end()) { return Status::KeyError("Unknown transaction ID: ", transaction_id); } + ARROW_LOG(INFO) << "Using connection for transaction " << transaction_id; return it->second; } @@ -346,18 +333,19 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> DoGetCatalogs( const ServerCallContext& context) { - // As SQLite doesn't support catalogs, this will return an empty record batch. - + // https://www.sqlite.org/cli.html + // > The ".databases" command shows a list of all databases open + // > in the current connection. There will always be at least + // > 2. The first one is "main", the original database opened. The + // > second is "temp", the database used for temporary tables. + // For our purposes, return only "main" and ignore other databases. const std::shared_ptr& schema = SqlSchema::GetCatalogsSchema(); - StringBuilder catalog_name_builder; + ARROW_RETURN_NOT_OK(catalog_name_builder.Append("main")); ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish()); - - const std::shared_ptr& batch = - RecordBatch::Make(schema, 0, {catalog_name}); - + std::shared_ptr batch = + RecordBatch::Make(schema, 1, {std::move(catalog_name)}); ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); - return std::make_unique(reader); } @@ -369,20 +357,26 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> DoGetDbSchemas( const ServerCallContext& context, const GetDbSchemas& command) { - // As SQLite doesn't support schemas, this will return an empty record batch. - + // SQLite doesn't support schemas, so pretend we have a single + // unnamed schema. const std::shared_ptr& schema = SqlSchema::GetDbSchemasSchema(); - StringBuilder catalog_name_builder; - ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish()); StringBuilder schema_name_builder; - ARROW_ASSIGN_OR_RAISE(auto schema_name, schema_name_builder.Finish()); - const std::shared_ptr& batch = - RecordBatch::Make(schema, 0, {catalog_name, schema_name}); + int64_t length = 0; + // XXX: we don't really implement the full pattern match here + if ((!command.catalog || command.catalog == "main") && + (!command.db_schema_filter_pattern || command.db_schema_filter_pattern == "%")) { + ARROW_RETURN_NOT_OK(catalog_name_builder.Append("main")); + ARROW_RETURN_NOT_OK(schema_name_builder.AppendNull()); + length++; + } + ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish()); + ARROW_ASSIGN_OR_RAISE(auto schema_name, schema_name_builder.Finish()); + std::shared_ptr batch = RecordBatch::Make( + schema, length, {std::move(catalog_name), std::move(schema_name)}); ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); - return std::make_unique(reader); } @@ -392,6 +386,7 @@ class SQLiteFlightSqlServer::Impl { std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; bool include_schema = command.include_schema; + ARROW_LOG(INFO) << "GetTables include_schema=" << include_schema; ARROW_ASSIGN_OR_RAISE( auto result, @@ -399,12 +394,13 @@ class SQLiteFlightSqlServer::Impl { : *SqlSchema::GetTablesSchema(), descriptor, endpoints, -1, -1)) - return std::make_unique(result); + return std::make_unique(std::move(result)); } arrow::Result> DoGetTables( const ServerCallContext& context, const GetTables& command) { std::string query = PrepareQueryForGetTables(command); + ARROW_LOG(INFO) << "GetTables: " << query; std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, query)); @@ -426,6 +422,7 @@ class SQLiteFlightSqlServer::Impl { const StatementUpdate& command) { const std::string& sql = command.query; ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)); + ARROW_LOG(INFO) << "Executing update: " << sql; ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db, sql)); return statement->ExecuteUpdate(); } @@ -433,11 +430,16 @@ class SQLiteFlightSqlServer::Impl { arrow::Result CreatePreparedStatement( const ServerCallContext& context, const ActionCreatePreparedStatementRequest& request) { - std::lock_guard guard(mutex_); std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, request.query)); + ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(request.transaction_id)); + ARROW_LOG(INFO) << "Creating prepared statement: " << request.query; + ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db, request.query)); std::string handle = GenerateRandomString(); - prepared_statements_[handle] = statement; + + { + std::lock_guard guard(mutex_); + prepared_statements_[handle] = statement; + } ARROW_ASSIGN_OR_RAISE(auto dataset_schema, statement->GetSchema()); @@ -525,10 +527,9 @@ class SQLiteFlightSqlServer::Impl { const std::string& prepared_statement_handle = command.prepared_statement_handle; ARROW_ASSIGN_OR_RAISE(auto statement, GetStatementByHandle(prepared_statement_handle)); - - sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); - ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); - + // Save params here and execute later + ARROW_ASSIGN_OR_RAISE(auto batches, reader->ToRecordBatches()); + ARROW_RETURN_NOT_OK(statement->SetParameters(std::move(batches))); return Status::OK(); } @@ -536,13 +537,20 @@ class SQLiteFlightSqlServer::Impl { const ServerCallContext& context, const PreparedStatementUpdate& command, FlightMessageReader* reader) { const std::string& prepared_statement_handle = command.prepared_statement_handle; - ARROW_ASSIGN_OR_RAISE(auto statement, + ARROW_ASSIGN_OR_RAISE(std::shared_ptr statement, GetStatementByHandle(prepared_statement_handle)); - sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); - ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); - - return statement->ExecuteUpdate(); + int64_t rows_affected = 0; + if (sqlite3_bind_parameter_count(statement->GetSqlite3Stmt()) == 0) { + ARROW_ASSIGN_OR_RAISE(rows_affected, statement->ExecuteUpdate()); + } else { + ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(statement.get(), reader, [&]() { + ARROW_ASSIGN_OR_RAISE(int64_t rows, statement->ExecuteUpdate()); + rows_affected += rows; + return statement->Reset().status(); + })); + } + return rows_affected; } arrow::Result> GetFlightInfoTableTypes( @@ -712,6 +720,8 @@ class SQLiteFlightSqlServer::Impl { ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION")); + ARROW_LOG(INFO) << "Beginning transaction on " << handle; + std::lock_guard guard(mutex_); open_transactions_[handle] = new_db; return ActionBeginTransactionResult{std::move(handle)}; @@ -729,8 +739,10 @@ class SQLiteFlightSqlServer::Impl { } if (request.action == ActionEndTransactionRequest::kCommit) { + ARROW_LOG(INFO) << "Committing on " << request.transaction_id; status = ExecuteSql(it->second, "COMMIT"); } else { + ARROW_LOG(INFO) << "Rolling back on " << request.transaction_id; status = ExecuteSql(it->second, "ROLLBACK"); } transaction = it->second; @@ -795,6 +807,7 @@ arrow::Result> SQLiteFlightSqlServer::Cre INSERT INTO intTable (keyName, value, foreignId) VALUES ('zero', 0, 1); INSERT INTO intTable (keyName, value, foreignId) VALUES ('negative one', -1, 1); INSERT INTO intTable (keyName, value, foreignId) VALUES (NULL, NULL, NULL); + INSERT INTO intTable (keyName, value, foreignId) VALUES ('null', NULL, NULL); )")); return result; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.h b/cpp/src/arrow/flight/sql/example/sqlite_server.h index 389a2d921bb..d8c84e36e68 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.h +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.h @@ -19,13 +19,14 @@ #include +#include #include #include -#include "arrow/api.h" #include "arrow/flight/sql/example/sqlite_statement.h" #include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" #include "arrow/flight/sql/server.h" +#include "arrow/result.h" namespace arrow { namespace flight { @@ -35,7 +36,7 @@ namespace example { /// \brief Convert a column type to a ArrowType. /// \param sqlite_type the sqlite type. /// \return The equivalent ArrowType. -std::shared_ptr GetArrowType(const char* sqlite_type); +arrow::Result> GetArrowType(const char* sqlite_type); /// \brief Convert a column type name to SQLite type. /// \param type_name the type name. diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc index d0810e2db02..23639256600 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_statement.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc @@ -17,19 +17,28 @@ #include "arrow/flight/sql/example/sqlite_statement.h" -#include +#include -#define BOOST_NO_CXX98_FUNCTION_BASE // ARROW-17805 -#include +#include +#include "arrow/array/array_base.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" #include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/scalar.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" namespace arrow { namespace flight { namespace sql { namespace example { +using arrow::internal::checked_cast; + std::shared_ptr GetDataTypeFromSqliteType(const int column_type) { switch (column_type) { case SQLITE_INTEGER: @@ -119,7 +128,7 @@ arrow::Result> SqliteStatement::GetSchema() const { // Try to retrieve column type from sqlite3_column_decltype const char* column_decltype = sqlite3_column_decltype(stmt_, i); if (column_decltype != NULLPTR) { - data_type = GetArrowType(column_decltype); + ARROW_ASSIGN_OR_RAISE(data_type, GetArrowType(column_decltype)); } else { // If it can not determine the actual column type, return a dense_union type // covering any type SQLite supports. @@ -160,10 +169,100 @@ arrow::Result SqliteStatement::Reset() { sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; } arrow::Result SqliteStatement::ExecuteUpdate() { - ARROW_RETURN_NOT_OK(Step()); + while (true) { + ARROW_ASSIGN_OR_RAISE(int rc, Step()); + if (rc == SQLITE_DONE) break; + } return sqlite3_changes(db_); } +Status SqliteStatement::SetParameters( + std::vector> parameters) { + const int num_params = sqlite3_bind_parameter_count(stmt_); + for (const auto& batch : parameters) { + if (batch->num_columns() != num_params) { + return Status::Invalid("Expected ", num_params, " parameters, but got ", + batch->num_columns()); + } + } + parameters_ = std::move(parameters); + auto end = std::remove_if( + parameters_.begin(), parameters_.end(), + [](const std::shared_ptr& batch) { return batch->num_rows() == 0; }); + parameters_.erase(end, parameters_.end()); + return Status::OK(); +} + +Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) { + if (batch_index >= parameters_.size()) { + return Status::IndexError("Cannot bind to batch ", batch_index); + } + const RecordBatch& batch = *parameters_[batch_index]; + if (row_index < 0 || row_index >= batch.num_rows()) { + return Status::IndexError("Cannot bind to row ", row_index, " in batch ", + batch_index); + } + + if (sqlite3_clear_bindings(stmt_) != SQLITE_OK) { + return Status::Invalid("Failed to reset bindings: ", sqlite3_errmsg(db_)); + } + for (int c = 0; c < batch.num_columns(); ++c) { + Array* column = batch.column(c).get(); + int64_t column_index = row_index; + if (column->type_id() == Type::DENSE_UNION) { + // Allow polymorphic bindings via union + const auto& u = checked_cast(*column); + column_index = u.value_offset(column_index); + column = u.field(u.child_id(row_index)).get(); + } + + int rc = 0; + if (column->IsNull(column_index)) { + rc = sqlite3_bind_null(stmt_, c + 1); + continue; + } + switch (column->type_id()) { + case Type::INT32: { + const int32_t value = + checked_cast(*column).Value(column_index); + rc = sqlite3_bind_int64(stmt_, c + 1, value); + break; + } + case Type::INT64: { + const int64_t value = + checked_cast(*column).Value(column_index); + rc = sqlite3_bind_int64(stmt_, c + 1, value); + break; + } + case Type::FLOAT: { + const float value = checked_cast(*column).Value(column_index); + rc = sqlite3_bind_double(stmt_, c + 1, value); + break; + } + case Type::DOUBLE: { + const double value = + checked_cast(*column).Value(column_index); + rc = sqlite3_bind_double(stmt_, c + 1, value); + break; + } + case Type::STRING: { + const std::string_view value = + checked_cast(*column).Value(column_index); + rc = sqlite3_bind_text(stmt_, c + 1, value.data(), static_cast(value.size()), + SQLITE_TRANSIENT); + break; + } + default: + return Status::TypeError("Received unsupported data type: ", *column->type()); + } + if (rc != SQLITE_OK) { + return Status::UnknownError("Failed to bind parameter: ", sqlite3_errmsg(db_)); + } + } + + return Status::OK(); +} + } // namespace example } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.h b/cpp/src/arrow/flight/sql/example/sqlite_statement.h index b31eab506fa..333a2d24457 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_statement.h +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.h @@ -64,13 +64,22 @@ class SqliteStatement { /// \return A sqlite statement. sqlite3_stmt* GetSqlite3Stmt() const; + sqlite3* db() const { return db_; } + /// \brief Executes an UPDATE, INSERT or DELETE statement. /// \return The number of rows changed by execution. arrow::Result ExecuteUpdate(); + const std::vector>& parameters() const { + return parameters_; + } + Status SetParameters(std::vector> parameters); + Status Bind(size_t batch_index, int64_t row_index); + private: sqlite3* db_; sqlite3_stmt* stmt_; + std::vector> parameters_; SqliteStatement(sqlite3* db, sqlite3_stmt* stmt) : db_(db), stmt_(stmt) {} }; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc index c247eb62875..27c72614c5d 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc @@ -54,10 +54,6 @@ case TYPE_CLASS##Type::type_id: { \ using c_type = typename TYPE_CLASS##Type::c_type; \ auto builder = reinterpret_cast(array_builder); \ - if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \ - ARROW_RETURN_NOT_OK(builder->AppendNull()); \ - break; \ - } \ const sqlite3_int64 value = sqlite3_column_int64(STMT, COLUMN); \ ARROW_RETURN_NOT_OK(builder->Append(static_cast(value))); \ break; \ @@ -66,10 +62,6 @@ #define FLOAT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ case TYPE_CLASS##Type::type_id: { \ auto builder = reinterpret_cast(array_builder); \ - if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \ - ARROW_RETURN_NOT_OK(builder->AppendNull()); \ - break; \ - } \ const double value = sqlite3_column_double(STMT, COLUMN); \ ARROW_RETURN_NOT_OK( \ builder->Append(static_cast(value))); \ @@ -82,7 +74,7 @@ namespace sql { namespace example { // Batch size for SQLite statement results -static constexpr int kMaxBatchSize = 1024; +static constexpr int32_t kMaxBatchSize = 16384; std::shared_ptr SqliteStatementBatchReader::schema() const { return schema_; } @@ -95,8 +87,12 @@ SqliteStatementBatchReader::SqliteStatementBatchReader( arrow::Result> SqliteStatementBatchReader::Create(const std::shared_ptr& statement_) { + ARROW_RETURN_NOT_OK(statement_->Reset()); + if (!statement_->parameters().empty()) { + // If there are parameters, infer the schema after binding the first row + ARROW_RETURN_NOT_OK(statement_->Bind(0, 0)); + } ARROW_RETURN_NOT_OK(statement_->Step()); - ARROW_ASSIGN_OR_RAISE(auto schema, statement_->GetSchema()); std::shared_ptr result( @@ -108,10 +104,8 @@ SqliteStatementBatchReader::Create(const std::shared_ptr& state arrow::Result> SqliteStatementBatchReader::Create(const std::shared_ptr& statement, const std::shared_ptr& schema) { - std::shared_ptr result( + return std::shared_ptr( new SqliteStatementBatchReader(statement, schema)); - - return result; } Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { @@ -127,61 +121,89 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { ARROW_RETURN_NOT_OK(MakeBuilder(default_memory_pool(), field_type, &builders[i])); } - if (!already_executed_) { - ARROW_ASSIGN_OR_RAISE(rc_, statement_->Reset()); - ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); - already_executed_ = true; - } - int64_t rows = 0; - while (rows < kMaxBatchSize && rc_ == SQLITE_ROW) { - rows++; - for (int i = 0; i < num_fields; i++) { - const std::shared_ptr& field = schema_->field(i); - const std::shared_ptr& field_type = field->type(); - ArrayBuilder* array_builder = builders[i].get(); - - // NOTE: This is not the optimal way of building Arrow vectors. - // That would be to presize the builders to avoiding several resizing operations - // when appending values and also to build one vector at a time. - switch (field_type->id()) { - // XXX This doesn't handle overflows when converting to the target - // integer type. - INT_BUILDER_CASE(Int64, stmt_, i) - INT_BUILDER_CASE(UInt64, stmt_, i) - INT_BUILDER_CASE(Int32, stmt_, i) - INT_BUILDER_CASE(UInt32, stmt_, i) - INT_BUILDER_CASE(Int16, stmt_, i) - INT_BUILDER_CASE(UInt16, stmt_, i) - INT_BUILDER_CASE(Int8, stmt_, i) - INT_BUILDER_CASE(UInt8, stmt_, i) - FLOAT_BUILDER_CASE(Double, stmt_, i) - FLOAT_BUILDER_CASE(Float, stmt_, i) - FLOAT_BUILDER_CASE(HalfFloat, stmt_, i) - BINARY_BUILDER_CASE(Binary, stmt_, i) - BINARY_BUILDER_CASE(LargeBinary, stmt_, i) - STRING_BUILDER_CASE(String, stmt_, i) - STRING_BUILDER_CASE(LargeString, stmt_, i) - default: - return Status::NotImplemented("Not implemented SQLite data conversion to ", - field_type->name()); + while (true) { + if (!already_executed_) { + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Reset()); + if (!statement_->parameters().empty()) { + if (batch_index_ >= statement_->parameters().size()) { + *out = nullptr; + break; + } + ARROW_RETURN_NOT_OK(statement_->Bind(batch_index_, row_index_)); } + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); + already_executed_ = true; } - ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); - } + while (rows < kMaxBatchSize && rc_ == SQLITE_ROW) { + rows++; + for (int i = 0; i < num_fields; i++) { + const std::shared_ptr& field = schema_->field(i); + const std::shared_ptr& field_type = field->type(); + ArrayBuilder* array_builder = builders[i].get(); + + if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { + ARROW_RETURN_NOT_OK(array_builder->AppendNull()); + continue; + } + + switch (field_type->id()) { + // XXX This doesn't handle overflows when converting to the target + // integer type. + INT_BUILDER_CASE(Int64, stmt_, i) + INT_BUILDER_CASE(UInt64, stmt_, i) + INT_BUILDER_CASE(Int32, stmt_, i) + INT_BUILDER_CASE(UInt32, stmt_, i) + INT_BUILDER_CASE(Int16, stmt_, i) + INT_BUILDER_CASE(UInt16, stmt_, i) + INT_BUILDER_CASE(Int8, stmt_, i) + INT_BUILDER_CASE(UInt8, stmt_, i) + FLOAT_BUILDER_CASE(Double, stmt_, i) + FLOAT_BUILDER_CASE(Float, stmt_, i) + FLOAT_BUILDER_CASE(HalfFloat, stmt_, i) + BINARY_BUILDER_CASE(Binary, stmt_, i) + BINARY_BUILDER_CASE(LargeBinary, stmt_, i) + STRING_BUILDER_CASE(String, stmt_, i) + STRING_BUILDER_CASE(LargeString, stmt_, i) + default: + return Status::NotImplemented("Not implemented SQLite data conversion to ", + field_type->name()); + } + } - if (rows > 0) { - std::vector> arrays(builders.size()); - for (int i = 0; i < num_fields; i++) { - ARROW_RETURN_NOT_OK(builders[i]->Finish(&arrays[i])); + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); } - *out = RecordBatch::Make(schema_, rows, arrays); - } else { - *out = NULLPTR; - } + // If we still have bind parameters, bind again and retry + const std::vector>& params = statement_->parameters(); + if (!params.empty() && rc_ == SQLITE_DONE && batch_index_ < params.size()) { + row_index_++; + if (row_index_ < params[batch_index_]->num_rows()) { + already_executed_ = false; + } else { + batch_index_++; + row_index_ = 0; + if (batch_index_ < params.size()) { + already_executed_ = false; + } + } + + if (!already_executed_ && rows < kMaxBatchSize) continue; + } + if (rows > 0) { + std::vector> arrays(builders.size()); + for (int i = 0; i < num_fields; i++) { + ARROW_RETURN_NOT_OK(builders[i]->Finish(&arrays[i])); + } + + *out = RecordBatch::Make(schema_, rows, arrays); + } else { + *out = nullptr; + } + break; + } return Status::OK(); } diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h index 8a6bc6078e7..3fb9ae1f83c 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h @@ -55,6 +55,10 @@ class SqliteStatementBatchReader : public RecordBatchReader { int rc_; bool already_executed_; + // State for parameter binding + size_t batch_index_{0}; + int64_t row_index_{0}; + SqliteStatementBatchReader(std::shared_ptr statement, std::shared_ptr schema); }; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc index 921dd13182e..55345ad477a 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc @@ -21,6 +21,7 @@ #include +#include "arrow/array/builder_binary.h" #include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/example/sqlite_server.h" #include "arrow/flight/sql/example/sqlite_statement.h" @@ -80,16 +81,19 @@ Status SqliteTablesWithSchemaBatchReader::ReadNext(std::shared_ptr* const ColumnMetadata& column_metadata = GetColumnMetadata( GetSqlTypeFromTypeName(column_type), sqlite_table_name.c_str()); - column_fields.push_back(arrow::field(column_name, GetArrowType(column_type), - nullable == 0, + std::shared_ptr arrow_type; + auto status = GetArrowType(column_type).Value(&arrow_type); + if (!status.ok()) { + return Status::NotImplemented("Unknown SQLite type '", column_type, + "' for column '", column_name, "' in table '", + table_name, "': ", status); + } + column_fields.push_back(arrow::field(column_name, arrow_type, nullable == 0, column_metadata.metadata_map())); } } - const arrow::Result>& value = - ipc::SerializeSchema(*arrow::schema(column_fields)); - - std::shared_ptr schema_buffer; - ARROW_ASSIGN_OR_RAISE(schema_buffer, value); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema_buffer, + ipc::SerializeSchema(*arrow::schema(column_fields))); column_fields.clear(); ARROW_RETURN_NOT_OK(schema_builder.Append(::std::string_view(*schema_buffer))); diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 80112c2a44d..7f6d9b75a88 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -1094,41 +1094,49 @@ arrow::Result FlightSqlServerBase::DoPutCommandSubstraitPlan( return Status::NotImplemented("DoPutCommandSubstraitPlan not implemented"); } -std::shared_ptr SqlSchema::GetCatalogsSchema() { - return arrow::schema({field("catalog_name", utf8(), false)}); +const std::shared_ptr& SqlSchema::GetCatalogsSchema() { + static std::shared_ptr kSchema = + arrow::schema({field("catalog_name", utf8(), false)}); + return kSchema; } -std::shared_ptr SqlSchema::GetDbSchemasSchema() { - return arrow::schema( +const std::shared_ptr& SqlSchema::GetDbSchemasSchema() { + static std::shared_ptr kSchema = arrow::schema( {field("catalog_name", utf8()), field("db_schema_name", utf8(), false)}); + return kSchema; } -std::shared_ptr SqlSchema::GetTablesSchema() { - return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8(), false), - field("table_type", utf8(), false)}); +const std::shared_ptr& SqlSchema::GetTablesSchema() { + static std::shared_ptr kSchema = arrow::schema( + {field("catalog_name", utf8()), field("db_schema_name", utf8()), + field("table_name", utf8(), false), field("table_type", utf8(), false)}); + return kSchema; } -std::shared_ptr SqlSchema::GetTablesSchemaWithIncludedSchema() { - return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8(), false), - field("table_type", utf8(), false), - field("table_schema", binary(), false)}); +const std::shared_ptr& SqlSchema::GetTablesSchemaWithIncludedSchema() { + static std::shared_ptr kSchema = arrow::schema( + {field("catalog_name", utf8()), field("db_schema_name", utf8()), + field("table_name", utf8(), false), field("table_type", utf8(), false), + field("table_schema", binary(), false)}); + return kSchema; } -std::shared_ptr SqlSchema::GetTableTypesSchema() { - return arrow::schema({field("table_type", utf8(), false)}); +const std::shared_ptr& SqlSchema::GetTableTypesSchema() { + static std::shared_ptr kSchema = + arrow::schema({field("table_type", utf8(), false)}); + return kSchema; } -std::shared_ptr SqlSchema::GetPrimaryKeysSchema() { - return arrow::schema( +const std::shared_ptr& SqlSchema::GetPrimaryKeysSchema() { + static std::shared_ptr kSchema = arrow::schema( {field("catalog_name", utf8()), field("db_schema_name", utf8()), field("table_name", utf8(), false), field("column_name", utf8(), false), field("key_sequence", int32(), false), field("key_name", utf8())}); + return kSchema; } -std::shared_ptr GetImportedExportedKeysAndCrossReferenceSchema() { - return arrow::schema( +const std::shared_ptr& GetImportedExportedKeysAndCrossReferenceSchema() { + static std::shared_ptr kSchema = arrow::schema( {field("pk_catalog_name", utf8(), true), field("pk_db_schema_name", utf8(), true), field("pk_table_name", utf8(), false), field("pk_column_name", utf8(), false), field("fk_catalog_name", utf8(), true), field("fk_db_schema_name", utf8(), true), @@ -1136,35 +1144,44 @@ std::shared_ptr GetImportedExportedKeysAndCrossReferenceSchema() { field("key_sequence", int32(), false), field("fk_key_name", utf8(), true), field("pk_key_name", utf8(), true), field("update_rule", uint8(), false), field("delete_rule", uint8(), false)}); + return kSchema; } -std::shared_ptr SqlSchema::GetImportedKeysSchema() { - return GetImportedExportedKeysAndCrossReferenceSchema(); +const std::shared_ptr& SqlSchema::GetImportedKeysSchema() { + static std::shared_ptr kSchema = + GetImportedExportedKeysAndCrossReferenceSchema(); + return kSchema; } -std::shared_ptr SqlSchema::GetExportedKeysSchema() { - return GetImportedExportedKeysAndCrossReferenceSchema(); +const std::shared_ptr& SqlSchema::GetExportedKeysSchema() { + static std::shared_ptr kSchema = + GetImportedExportedKeysAndCrossReferenceSchema(); + return kSchema; } -std::shared_ptr SqlSchema::GetCrossReferenceSchema() { - return GetImportedExportedKeysAndCrossReferenceSchema(); +const std::shared_ptr& SqlSchema::GetCrossReferenceSchema() { + static std::shared_ptr kSchema = + GetImportedExportedKeysAndCrossReferenceSchema(); + return kSchema; } -std::shared_ptr SqlSchema::GetSqlInfoSchema() { - return arrow::schema({field("info_name", uint32(), false), - field("value", - dense_union({field("string_value", utf8(), false), - field("bool_value", boolean(), false), - field("bigint_value", int64(), false), - field("int32_bitmask", int32(), false), - field("string_list", list(utf8()), false), - field("int32_to_int32_list_map", - map(int32(), list(int32())), false)}), - false)}); +const std::shared_ptr& SqlSchema::GetSqlInfoSchema() { + static std::shared_ptr kSchema = + arrow::schema({field("info_name", uint32(), false), + field("value", + dense_union({field("string_value", utf8(), false), + field("bool_value", boolean(), false), + field("bigint_value", int64(), false), + field("int32_bitmask", int32(), false), + field("string_list", list(utf8()), false), + field("int32_to_int32_list_map", + map(int32(), list(int32())), false)}), + false)}); + return kSchema; } -std::shared_ptr SqlSchema::GetXdbcTypeInfoSchema() { - return arrow::schema({ +const std::shared_ptr& SqlSchema::GetXdbcTypeInfoSchema() { + static std::shared_ptr kSchema = arrow::schema({ field("type_name", utf8(), false), field("data_type", int32(), false), field("column_size", int32()), @@ -1185,6 +1202,7 @@ std::shared_ptr SqlSchema::GetXdbcTypeInfoSchema() { field("num_prec_radix", int32()), field("interval_precision", int32()), }); + return kSchema; } } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index 0fc8b714865..65f6670171d 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -686,50 +686,50 @@ class ARROW_FLIGHT_SQL_EXPORT SqlSchema { public: /// \brief Get the Schema used on GetCatalogs response. /// \return The default schema template. - static std::shared_ptr GetCatalogsSchema(); + static const std::shared_ptr& GetCatalogsSchema(); /// \brief Get the Schema used on GetDbSchemas response. /// \return The default schema template. - static std::shared_ptr GetDbSchemasSchema(); + static const std::shared_ptr& GetDbSchemasSchema(); /// \brief Get the Schema used on GetTables response when included schema /// flags is set to false. /// \return The default schema template. - static std::shared_ptr GetTablesSchema(); + static const std::shared_ptr& GetTablesSchema(); /// \brief Get the Schema used on GetTables response when included schema /// flags is set to true. /// \return The default schema template. - static std::shared_ptr GetTablesSchemaWithIncludedSchema(); + static const std::shared_ptr& GetTablesSchemaWithIncludedSchema(); /// \brief Get the Schema used on GetTableTypes response. /// \return The default schema template. - static std::shared_ptr GetTableTypesSchema(); + static const std::shared_ptr& GetTableTypesSchema(); /// \brief Get the Schema used on GetPrimaryKeys response when included schema /// flags is set to true. /// \return The default schema template. - static std::shared_ptr GetPrimaryKeysSchema(); + static const std::shared_ptr& GetPrimaryKeysSchema(); /// \brief Get the Schema used on GetImportedKeys response. /// \return The default schema template. - static std::shared_ptr GetExportedKeysSchema(); + static const std::shared_ptr& GetExportedKeysSchema(); /// \brief Get the Schema used on GetImportedKeys response. /// \return The default schema template. - static std::shared_ptr GetImportedKeysSchema(); + static const std::shared_ptr& GetImportedKeysSchema(); /// \brief Get the Schema used on GetCrossReference response. /// \return The default schema template. - static std::shared_ptr GetCrossReferenceSchema(); + static const std::shared_ptr& GetCrossReferenceSchema(); /// \brief Get the Schema used on GetXdbcTypeInfo response. /// \return The default schema template. - static std::shared_ptr GetXdbcTypeInfoSchema(); + static const std::shared_ptr& GetXdbcTypeInfoSchema(); /// \brief Get the Schema used on GetSqlInfo response. /// \return The default schema template. - static std::shared_ptr GetSqlInfoSchema(); + static const std::shared_ptr& GetSqlInfoSchema(); }; } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index dc59b4a2c1c..0eedbda3033 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -15,34 +15,32 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/flight/sql/server.h" +#include -#include #include #include - -#include #include -#include "arrow/flight/api.h" -#include "arrow/flight/sql/api.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/flight/sql/client.h" #include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/example/sqlite_server.h" #include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/sql/example/sqlite_type_info.h" +#include "arrow/flight/sql/server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" +#include "arrow/record_batch.h" +#include "arrow/scalar.h" +#include "arrow/table.h" #include "arrow/testing/builder.h" #include "arrow/testing/gtest_util.h" -using ::testing::_; -using ::testing::Ref; - using arrow::internal::checked_cast; -namespace arrow { -namespace flight { -namespace sql { +namespace arrow::flight::sql { /// \brief Auxiliary variant visitor used to assert that GetSqlInfo's values are /// correctly placed on its DenseUnionArray @@ -158,7 +156,7 @@ class TestFlightSqlServer : public ::testing::Test { ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", server->port())); ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location)); - sql_client.reset(new FlightSqlClient(std::move(client))); + sql_client = std::make_unique(std::move(client)); } void TearDown() override { @@ -185,11 +183,11 @@ TEST_F(TestFlightSqlServer, TestCommandStatementQuery) { arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), arrow::field("value", int64()), arrow::field("foreignId", int64())}); - const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4])"); + const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4, 5])"); const auto keyname_array = - ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null])"); - const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null])"); - const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null])"); + ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null, "null"])"); + const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null, null])"); + const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null, null])"); const std::shared_ptr& expected_table = Table::Make( expected_schema, {id_array, keyname_array, value_array, foreignId_array}); @@ -215,16 +213,14 @@ TEST_F(TestFlightSqlServer, TestCommandGetTables) { ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); - ASSERT_OK_AND_ASSIGN(auto catalog_name, MakeArrayOfNull(utf8(), 3)) - ASSERT_OK_AND_ASSIGN(auto schema_name, MakeArrayOfNull(utf8(), 3)) - + const auto catalog_name = ArrayFromJSON(utf8(), R"(["main", "main", "main"])"); + ASSERT_OK_AND_ASSIGN(auto schema_name, MakeArrayOfNull(utf8(), 3)); const auto table_name = ArrayFromJSON(utf8(), R"(["foreignTable", "intTable", "sqlite_sequence"])"); const auto table_type = ArrayFromJSON(utf8(), R"(["table", "table", "table"])"); - const std::shared_ptr
& expected_table = Table::Make( + std::shared_ptr
expected_table = Table::Make( SqlSchema::GetTablesSchema(), {catalog_name, schema_name, table_name, table_type}); - AssertTablesEqual(*expected_table, *table); } @@ -246,7 +242,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTablesWithTableFilter) { ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); - const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto catalog_name = ArrayFromJSON(utf8(), R"(["main"])"); const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); const auto table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); const auto table_type = ArrayFromJSON(utf8(), R"(["table"])"); @@ -298,7 +294,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTablesWithUnexistenceTableTypeFilter) ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); - const auto catalog_name = ArrayFromJSON(utf8(), R"([null, null, null])"); + const auto catalog_name = ArrayFromJSON(utf8(), R"(["main", "main", "main"])"); const auto schema_name = ArrayFromJSON(utf8(), R"([null, null, null])"); const auto table_name = ArrayFromJSON(utf8(), R"(["foreignTable", "intTable", "sqlite_sequence"])"); @@ -330,7 +326,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTablesWithIncludedSchemas) { const char* db_table_name = "intTable"; - const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto catalog_name = ArrayFromJSON(utf8(), R"(["main"])"); const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); const auto table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); const auto table_type = ArrayFromJSON(utf8(), R"(["table"])"); @@ -392,16 +388,11 @@ TEST_F(TestFlightSqlServer, TestCommandGetTypeInfoWithFiltering) { TEST_F(TestFlightSqlServer, TestCommandGetCatalogs) { ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetCatalogs({})); - ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); - - const std::shared_ptr& expected_schema = SqlSchema::GetCatalogsSchema(); - - AssertSchemaEqual(expected_schema, table->schema()); - ASSERT_EQ(0, table->num_rows()); + auto expected_table = TableFromJSON(SqlSchema::GetCatalogsSchema(), {R"([["main"]])"}); + ASSERT_NO_FATAL_FAILURE(AssertTablesEqual(*expected_table, *table, /*verbose=*/true)); } TEST_F(TestFlightSqlServer, TestCommandGetDbSchemas) { @@ -410,16 +401,12 @@ TEST_F(TestFlightSqlServer, TestCommandGetDbSchemas) { std::string* schema_filter_pattern = nullptr; ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetDbSchemas(options, catalog, schema_filter_pattern)); - ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); - - const std::shared_ptr& expected_schema = SqlSchema::GetDbSchemasSchema(); - - AssertSchemaEqual(expected_schema, table->schema()); - ASSERT_EQ(0, table->num_rows()); + auto expected_table = + TableFromJSON(SqlSchema::GetDbSchemasSchema(), {R"([["main", null]])"}); + ASSERT_NO_FATAL_FAILURE(AssertTablesEqual(*expected_table, *table, /*verbose=*/true)); } TEST_F(TestFlightSqlServer, TestCommandGetTableTypes) { @@ -485,11 +472,11 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementQuery) { "foreignId", int64(), example::GetColumnMetadata(SQLITE_INTEGER, db_table_name).metadata_map())}); - const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4])"); + const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4, 5])"); const auto keyname_array = - ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null])"); - const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null])"); - const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null])"); + ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null, "null"])"); + const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null, null])"); + const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null, null])"); const std::shared_ptr
& expected_table = Table::Make( expected_schema, {id_array, keyname_array, value_array, foreignId_array}); @@ -502,51 +489,57 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementQueryWithParameterBindin auto prepared_statement, sql_client->Prepare({}, "SELECT * FROM intTable WHERE keyName LIKE ?")); - auto parameter_schema = prepared_statement->parameter_schema(); - + const std::shared_ptr& parameter_schema = + prepared_statement->parameter_schema(); const std::shared_ptr& expected_parameter_schema = arrow::schema({arrow::field("parameter_1", example::GetUnknownColumnDataType())}); + ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(expected_parameter_schema, parameter_schema)); - AssertSchemaEqual(expected_parameter_schema, parameter_schema); - - std::shared_ptr type_ids = ArrayFromJSON(int8(), R"([0])"); - std::shared_ptr offsets = ArrayFromJSON(int32(), R"([0])"); - std::shared_ptr string_array = ArrayFromJSON(utf8(), R"(["%one"])"); - std::shared_ptr bytes_array = ArrayFromJSON(binary(), R"([])"); - std::shared_ptr bigint_array = ArrayFromJSON(int64(), R"([])"); - std::shared_ptr double_array = ArrayFromJSON(float64(), R"([])"); - - ASSERT_OK_AND_ASSIGN( - auto parameter_1_array, - DenseUnionArray::Make(*type_ids, *offsets, - {string_array, bytes_array, bigint_array, double_array}, - {"string", "bytes", "bigint", "double"}, {0, 1, 2, 3})); - - const std::shared_ptr& record_batch = - RecordBatch::Make(parameter_schema, 1, {parameter_1_array}); - - ASSERT_OK(prepared_statement->SetParameters(record_batch)); + auto record_batch = RecordBatchFromJSON(parameter_schema, R"([ [[0, "%one"]] ])"); + ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch))); ASSERT_OK_AND_ASSIGN(auto flight_info, prepared_statement->Execute()); - ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const std::shared_ptr& expected_schema = arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), arrow::field("value", int64()), arrow::field("foreignId", int64())}); - const auto id_array = ArrayFromJSON(int64(), R"([1, 3])"); - const auto keyname_array = ArrayFromJSON(utf8(), R"(["one", "negative one"])"); - const auto value_array = ArrayFromJSON(int64(), R"([1, -1])"); - const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1])"); - - const std::shared_ptr
& expected_table = Table::Make( - expected_schema, {id_array, keyname_array, value_array, foreignId_array}); - - AssertTablesEqual(*expected_table, *table); + auto expected_table = TableFromJSON(expected_schema, {R"([ + [1, "one", 1, 1], + [3, "negative one", -1, 1] + ])"}); + ASSERT_NO_FATAL_FAILURE(AssertTablesEqual(*expected_table, *table, /*verbose=*/true)); + + // Set multiple parameters at once + record_batch = RecordBatchFromJSON( + parameter_schema, R"([ [[0, "%one"]], [[0, "%zero"]], [[0, "null"]] ])"); + ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch))); + ASSERT_OK_AND_ASSIGN(flight_info, prepared_statement->Execute()); + ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(table, stream->ToTable()); + expected_table = TableFromJSON(expected_schema, {R"([ + [1, "one", 1, 1], + [3, "negative one", -1, 1], + [2, "zero", 0, 1], + [5, "null", null, null] + ])"}); + ASSERT_NO_FATAL_FAILURE(AssertTablesEqual(*expected_table, *table, /*verbose=*/true)); + + // Set a stream of parameters + ASSERT_OK_AND_ASSIGN( + auto reader, + RecordBatchReader::Make({ + RecordBatchFromJSON(parameter_schema, R"([ [[0, "%one"]], [[0, "%zero"]] ])"), + RecordBatchFromJSON(parameter_schema, R"([ [[0, "%null%"]] ])"), + })); + ASSERT_OK(prepared_statement->SetParameters(std::move(reader))); + ASSERT_OK_AND_ASSIGN(flight_info, prepared_statement->Execute()); + ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(table, stream->ToTable()); + ASSERT_NO_FATAL_FAILURE(AssertTablesEqual(*expected_table, *table, /*verbose=*/true)); } TEST_F(TestFlightSqlServer, TestCommandPreparedStatementUpdateWithParameterBinding) { @@ -555,41 +548,38 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementUpdateWithParameterBindi sql_client->Prepare( {}, "INSERT INTO INTTABLE (keyName, value) VALUES ('new_value', ?)")); - auto parameter_schema = prepared_statement->parameter_schema(); - + const std::shared_ptr& parameter_schema = + prepared_statement->parameter_schema(); const std::shared_ptr& expected_parameter_schema = arrow::schema({arrow::field("parameter_1", example::GetUnknownColumnDataType())}); + ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(expected_parameter_schema, parameter_schema)); - AssertSchemaEqual(expected_parameter_schema, parameter_schema); - - std::shared_ptr type_ids = ArrayFromJSON(int8(), R"([2])"); - std::shared_ptr offsets = ArrayFromJSON(int32(), R"([0])"); - std::shared_ptr string_array = ArrayFromJSON(utf8(), R"([])"); - std::shared_ptr bytes_array = ArrayFromJSON(binary(), R"([])"); - std::shared_ptr bigint_array = ArrayFromJSON(int64(), R"([999])"); - std::shared_ptr double_array = ArrayFromJSON(float64(), R"([])"); - - ASSERT_OK_AND_ASSIGN( - auto parameter_1_array, - DenseUnionArray::Make(*type_ids, *offsets, - {string_array, bytes_array, bigint_array, double_array}, - {"string", "bytes", "bigint", "double"}, {0, 1, 2, 3})); - - const std::shared_ptr& record_batch = - RecordBatch::Make(parameter_schema, 1, {parameter_1_array}); - - ASSERT_OK(prepared_statement->SetParameters(record_batch)); - - ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); - - ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); + auto record_batch = RecordBatchFromJSON(parameter_schema, R"([ [[2, 999]] ])"); + ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch))); ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); - + ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); + ASSERT_OK_AND_EQ(6, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); ASSERT_OK_AND_EQ(1, sql_client->ExecuteUpdate( {}, "DELETE FROM intTable WHERE keyName = 'new_value'")); + ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + + // Set multiple parameters at once + record_batch = RecordBatchFromJSON(parameter_schema, R"([ [[2, 999]], [[2, 42]] ])"); + ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch))); + ASSERT_OK_AND_EQ(2, prepared_statement->ExecuteUpdate()); + ASSERT_OK_AND_EQ(7, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); - ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + // Set a stream of parameters + ASSERT_OK_AND_ASSIGN( + auto reader, + RecordBatchReader::Make({ + RecordBatchFromJSON(parameter_schema, R"([ [[2, 999]], [[2, 42]] ])"), + RecordBatchFromJSON(parameter_schema, R"([ [[2, -1]] ])"), + })); + ASSERT_OK(prepared_statement->SetParameters(std::move(reader))); + ASSERT_OK_AND_EQ(3, prepared_statement->ExecuteUpdate()); + ASSERT_OK_AND_EQ(10, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); } TEST_F(TestFlightSqlServer, TestCommandPreparedStatementUpdate) { @@ -598,16 +588,12 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementUpdate) { sql_client->Prepare( {}, "INSERT INTO INTTABLE (keyName, value) VALUES ('new_value', 999)")); - ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); - - ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); - ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); - + ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); + ASSERT_OK_AND_EQ(6, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); ASSERT_OK_AND_EQ(1, sql_client->ExecuteUpdate( {}, "DELETE FROM intTable WHERE keyName = 'new_value'")); - - ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); } TEST_F(TestFlightSqlServer, TestCommandGetPrimaryKeys) { @@ -819,6 +805,4 @@ TEST_F(TestFlightSqlServer, Transactions) { ASSERT_EQ(table->num_rows(), row_count); } -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace arrow::flight::sql