Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions cpp/src/arrow/flight/sql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,6 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES})
set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc)

if(NOT MSVC AND NOT MINGW)
# ARROW-16902: getting Protobuf generated code to have all the
# proper dllexport/dllimport declarations is difficult, since
# protoc does not insert them everywhere needed to satisfy both
# MinGW and MSVC, and the Protobuf team recommends against it
list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS client_test.cc)
endif()

if(ARROW_COMPUTE
AND ARROW_PARQUET
AND ARROW_SUBSTRAIT)
Expand Down
78 changes: 44 additions & 34 deletions cpp/src/arrow/flight/sql/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

// Platform-specific defines
#include "arrow/flight/client.h"
#include "arrow/flight/platform.h"

#include "arrow/flight/sql/client.h"
Expand Down Expand Up @@ -208,11 +209,12 @@ arrow::Result<int64_t> FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o
std::unique_ptr<FlightMetadataReader> reader;

ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader));

std::shared_ptr<Buffer> metadata;
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata));
ARROW_RETURN_NOT_OK(writer->Close());

if (!metadata) return Status::IOError("Server did not send a response");

flight_sql_pb::DoPutUpdateResult result;
if (!result.ParseFromArray(metadata->data(), static_cast<int>(metadata->size()))) {
return Status::Invalid("Unable to parse DoPutUpdateResult");
Expand Down Expand Up @@ -535,6 +537,24 @@ arrow::Result<std::shared_ptr<PreparedStatement>> PreparedStatement::ParseRespon
parameter_schema);
}

arrow::Result<std::shared_ptr<Buffer>> BindParameters(FlightClient* client,
const FlightCallOptions& options,
const FlightDescriptor& descriptor,
RecordBatchReader* params) {
ARROW_ASSIGN_OR_RAISE(auto stream,
client->DoPut(options, descriptor, params->schema()));
while (true) {
ARROW_ASSIGN_OR_RAISE(auto batch, params->Next());
if (!batch) break;
ARROW_RETURN_NOT_OK(stream.writer->WriteRecordBatch(*batch));
}
ARROW_RETURN_NOT_OK(stream.writer->DoneWriting());
std::shared_ptr<Buffer> metadata;
ARROW_RETURN_NOT_OK(stream.reader->ReadMetadata(&metadata));
ARROW_RETURN_NOT_OK(stream.writer->Close());
return metadata;
}

arrow::Result<std::unique_ptr<FlightInfo>> PreparedStatement::Execute(
const FlightCallOptions& options) {
if (is_closed_) {
Expand All @@ -545,21 +565,11 @@ arrow::Result<std::unique_ptr<FlightInfo>> PreparedStatement::Execute(
command.set_prepared_statement_handle(handle_);
ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
GetFlightDescriptorForCommand(command));

if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
std::unique_ptr<FlightStreamWriter> writer;
std::unique_ptr<FlightMetadataReader> reader;
ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(),
&writer, &reader));

ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
ARROW_RETURN_NOT_OK(writer->DoneWriting());
// Wait for the server to ack the result
std::shared_ptr<Buffer> buffer;
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer));
ARROW_RETURN_NOT_OK(writer->Close());
if (parameter_binding_) {
ARROW_ASSIGN_OR_RAISE(auto metadata,
BindParameters(client_->impl_.get(), options, descriptor,
parameter_binding_.get()));
}

ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor));
return std::move(flight_info);
}
Expand All @@ -574,26 +584,19 @@ arrow::Result<int64_t> PreparedStatement::ExecuteUpdate(
command.set_prepared_statement_handle(handle_);
ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
GetFlightDescriptorForCommand(command));
std::unique_ptr<FlightStreamWriter> writer;
std::unique_ptr<FlightMetadataReader> reader;

if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(),
&writer, &reader));
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
std::shared_ptr<Buffer> metadata;
if (parameter_binding_) {
ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), options,
descriptor, parameter_binding_.get()));
} else {
const std::shared_ptr<Schema> schema = arrow::schema({});
ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, schema, &writer, &reader));
const ArrayVector columns;
const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch));
ARROW_ASSIGN_OR_RAISE(auto params, RecordBatchReader::Make({}, schema));
ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), options,
descriptor, params.get()));
}
if (!metadata) {
return Status::IOError("Server did not send a response to ", command.GetTypeName());
}

ARROW_RETURN_NOT_OK(writer->DoneWriting());
std::shared_ptr<Buffer> metadata;
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata));
ARROW_RETURN_NOT_OK(writer->Close());

flight_sql_pb::DoPutUpdateResult result;
if (!result.ParseFromArray(metadata->data(), static_cast<int>(metadata->size()))) {
return Status::Invalid("Unable to parse DoPutUpdateResult object.");
Expand All @@ -603,18 +606,25 @@ arrow::Result<int64_t> PreparedStatement::ExecuteUpdate(
}

Status PreparedStatement::SetParameters(std::shared_ptr<RecordBatch> parameter_binding) {
ARROW_ASSIGN_OR_RAISE(parameter_binding_,
RecordBatchReader::Make({std::move(parameter_binding)}));
return Status::OK();
}

Status PreparedStatement::SetParameters(
std::shared_ptr<RecordBatchReader> parameter_binding) {
parameter_binding_ = std::move(parameter_binding);

return Status::OK();
}

bool PreparedStatement::IsClosed() const { return is_closed_; }

std::shared_ptr<Schema> PreparedStatement::dataset_schema() const {
const std::shared_ptr<Schema>& PreparedStatement::dataset_schema() const {
return dataset_schema_;
}

std::shared_ptr<Schema> PreparedStatement::parameter_schema() const {
const std::shared_ptr<Schema>& PreparedStatement::parameter_schema() const {
return parameter_schema_;
}

Expand Down
11 changes: 6 additions & 5 deletions cpp/src/arrow/flight/sql/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,17 +392,18 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {

/// \brief Retrieve the parameter schema from the query.
/// \return The parameter schema from the query.
std::shared_ptr<Schema> parameter_schema() const;
const std::shared_ptr<Schema>& parameter_schema() const;

/// \brief Retrieve the ResultSet schema from the query.
/// \return The ResultSet schema from the query.
std::shared_ptr<Schema> dataset_schema() const;
const std::shared_ptr<Schema>& dataset_schema() const;

/// \brief Set a RecordBatch that contains the parameters that will be bound.
/// \param parameter_binding The parameters that will be bound.
/// \return Status.
Status SetParameters(std::shared_ptr<RecordBatch> parameter_binding);

/// \brief Set a RecordBatchReader that contains the parameters that will be bound.
Status SetParameters(std::shared_ptr<RecordBatchReader> parameter_binding);

/// \brief Re-request the result set schema from the server (should
/// be identical to dataset_schema).
arrow::Result<std::unique_ptr<SchemaResult>> GetSchema(
Expand All @@ -422,7 +423,7 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
std::string handle_;
std::shared_ptr<Schema> dataset_schema_;
std::shared_ptr<Schema> parameter_schema_;
std::shared_ptr<RecordBatch> parameter_binding_;
std::shared_ptr<RecordBatchReader> parameter_binding_;
bool is_closed_;
};

Expand Down
Loading