diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 5ababa32953..b14559d12a1 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -130,6 +130,7 @@ jobs: ARROW_PLASMA: ON ARROW_PYTHON: ON ARROW_S3: ON + ARROW_SUBSTRAIT: ON ARROW_WITH_ZLIB: ON ARROW_WITH_LZ4: ON ARROW_WITH_BZ2: ON diff --git a/ci/scripts/python_build.sh b/ci/scripts/python_build.sh index e87117ce877..b90321643c7 100755 --- a/ci/scripts/python_build.sh +++ b/ci/scripts/python_build.sh @@ -64,6 +64,7 @@ export PYARROW_WITH_PLASMA=${ARROW_PLASMA:-OFF} export PYARROW_WITH_PARQUET=${ARROW_PARQUET:-OFF} export PYARROW_WITH_PARQUET_ENCRYPTION=${PARQUET_REQUIRE_ENCRYPTION:-ON} export PYARROW_WITH_S3=${ARROW_S3:-OFF} +export PYARROW_WITH_SUBSTRAIT=${ARROW_SUBSTRAIT:-OFF} export PYARROW_PARALLEL=${n_jobs} diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index d09b8819fb1..ea9797ea1d7 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -26,7 +26,8 @@ set(ARROW_SUBSTRAIT_SRCS substrait/serde.cc substrait/plan_internal.cc substrait/relation_internal.cc - substrait/type_internal.cc) + substrait/type_internal.cc + substrait/util.cc) add_arrow_lib(arrow_substrait CMAKE_PACKAGE_NAME diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 775e2520e0b..deee2d14456 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -16,6 +16,7 @@ // under the License. #include "arrow/engine/substrait/serde.h" +#include "arrow/engine/substrait/util.h" #include #include @@ -752,5 +753,66 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) { &ext_set)); } +Result<std::string> GetSubstraitJSON() { + ARROW_ASSIGN_OR_RAISE(std::string dir_string, + arrow::internal::GetEnvVar("PARQUET_TEST_DATA")); + auto file_name = + arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet"); + auto file_path = file_name->ToString(); + std::string substrait_json = R"({ + "relations": [ + {"rel": { + "read": { + "base_schema": { + "struct": { + "types": [ + {"binary": {}} + ] + }, + "names": [ + "foo" + ] + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER", + "format": "FILE_FORMAT_PARQUET" + } + ] + } + } + }} + ] + })"; + std::string filename_placeholder = "FILENAME_PLACEHOLDER"; + substrait_json.replace(substrait_json.find(filename_placeholder), + filename_placeholder.size(), file_path); + return substrait_json; +} + +TEST(Substrait, GetRecordBatchReader) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#else + ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON()); + ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); + ASSERT_OK_AND_ASSIGN(auto reader, substrait::ExecuteSerializedPlan(*buf)); + ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatchReader(reader.get())); + // Note: assuming the binary.parquet file contains a fixed number of records; + // in case of a test failure, re-evaluate the content of the file + EXPECT_EQ(table->num_rows(), 12); +#endif +} + +TEST(Substrait, InvalidPlan) { + std::string substrait_json = R"({ + "relations": [ + ] + })"; + ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); + ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf)); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc
b/cpp/src/arrow/engine/substrait/type_internal.cc index c1dac97b682..c7b94b41040 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -46,7 +46,7 @@ Status CheckVariation(const TypeMessage& type) { template bool IsNullable(const TypeMessage& type) { // FIXME what can we do with NULLABILITY_UNSPECIFIED - return type.nullability() != substrait::Type::NULLABILITY_REQUIRED; + return type.nullability() != ::substrait::Type::NULLABILITY_REQUIRED; } template @@ -99,66 +99,66 @@ Result FieldsFromProto(int size, const Types& types, } // namespace Result, bool>> FromProto( - const substrait::Type& type, const ExtensionSet& ext_set) { + const ::substrait::Type& type, const ExtensionSet& ext_set) { switch (type.kind_case()) { - case substrait::Type::kBool: + case ::substrait::Type::kBool: return FromProtoImpl(type.bool_()); - case substrait::Type::kI8: + case ::substrait::Type::kI8: return FromProtoImpl(type.i8()); - case substrait::Type::kI16: + case ::substrait::Type::kI16: return FromProtoImpl(type.i16()); - case substrait::Type::kI32: + case ::substrait::Type::kI32: return FromProtoImpl(type.i32()); - case substrait::Type::kI64: + case ::substrait::Type::kI64: return FromProtoImpl(type.i64()); - case substrait::Type::kFp32: + case ::substrait::Type::kFp32: return FromProtoImpl(type.fp32()); - case substrait::Type::kFp64: + case ::substrait::Type::kFp64: return FromProtoImpl(type.fp64()); - case substrait::Type::kString: + case ::substrait::Type::kString: return FromProtoImpl(type.string()); - case substrait::Type::kBinary: + case ::substrait::Type::kBinary: return FromProtoImpl(type.binary()); - case substrait::Type::kTimestamp: + case ::substrait::Type::kTimestamp: return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); - case substrait::Type::kTimestampTz: + case ::substrait::Type::kTimestampTz: return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, TimestampTzTimezoneString()); - case substrait::Type::kDate: + case ::substrait::Type::kDate: return FromProtoImpl(type.date()); - case substrait::Type::kTime: + case ::substrait::Type::kTime: return FromProtoImpl(type.time(), TimeUnit::MICRO); - case substrait::Type::kIntervalYear: + case ::substrait::Type::kIntervalYear: return FromProtoImpl(type.interval_year(), interval_year); - case substrait::Type::kIntervalDay: + case ::substrait::Type::kIntervalDay: return FromProtoImpl(type.interval_day(), interval_day); - case substrait::Type::kUuid: + case ::substrait::Type::kUuid: return FromProtoImpl(type.uuid(), uuid); - case substrait::Type::kFixedChar: + case ::substrait::Type::kFixedChar: return FromProtoImpl(type.fixed_char(), fixed_char, type.fixed_char().length()); - case substrait::Type::kVarchar: + case ::substrait::Type::kVarchar: return FromProtoImpl(type.varchar(), varchar, type.varchar().length()); - case substrait::Type::kFixedBinary: + case ::substrait::Type::kFixedBinary: return FromProtoImpl(type.fixed_binary(), type.fixed_binary().length()); - case substrait::Type::kDecimal: { + case ::substrait::Type::kDecimal: { const auto& decimal = type.decimal(); return FromProtoImpl(decimal, decimal.precision(), decimal.scale()); } - case substrait::Type::kStruct: { + case ::substrait::Type::kStruct: { const auto& struct_ = type.struct_(); ARROW_ASSIGN_OR_RAISE(auto fields, FieldsFromProto( @@ -168,7 +168,7 @@ Result, bool>> FromProto( return FromProtoImpl(struct_, std::move(fields)); } - case substrait::Type::kList: { + case ::substrait::Type::kList: { const auto& list = 
type.list(); if (!list.has_type()) { @@ -182,7 +182,7 @@ Result, bool>> FromProto( list, field("item", std::move(type_nullable.first), type_nullable.second)); } - case substrait::Type::kMap: { + case ::substrait::Type::kMap: { const auto& map = type.map(); static const std::array kMissing = {"key and value", "value", "key", @@ -206,7 +206,7 @@ Result, bool>> FromProto( field("value", std::move(value_nullable.first), value_nullable.second)); } - case substrait::Type::kUserDefinedTypeReference: { + case ::substrait::Type::kUserDefinedTypeReference: { uint32_t anchor = type.user_defined_type_reference(); ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); return std::make_pair(std::move(type_record.type), true); @@ -226,18 +226,20 @@ struct DataTypeToProtoImpl { Status Visit(const NullType& t) { return EncodeUserDefined(t); } Status Visit(const BooleanType& t) { - return SetWith(&substrait::Type::set_allocated_bool_); + return SetWith(&::substrait::Type::set_allocated_bool_); } - Status Visit(const Int8Type& t) { return SetWith(&substrait::Type::set_allocated_i8); } + Status Visit(const Int8Type& t) { + return SetWith(&::substrait::Type::set_allocated_i8); + } Status Visit(const Int16Type& t) { - return SetWith(&substrait::Type::set_allocated_i16); + return SetWith(&::substrait::Type::set_allocated_i16); } Status Visit(const Int32Type& t) { - return SetWith(&substrait::Type::set_allocated_i32); + return SetWith(&::substrait::Type::set_allocated_i32); } Status Visit(const Int64Type& t) { - return SetWith(&substrait::Type::set_allocated_i64); + return SetWith(&::substrait::Type::set_allocated_i64); } Status Visit(const UInt8Type& t) { return EncodeUserDefined(t); } @@ -247,26 +249,27 @@ struct DataTypeToProtoImpl { Status Visit(const HalfFloatType& t) { return EncodeUserDefined(t); } Status Visit(const FloatType& t) { - return SetWith(&substrait::Type::set_allocated_fp32); + return SetWith(&::substrait::Type::set_allocated_fp32); } Status Visit(const DoubleType& t) { - return SetWith(&substrait::Type::set_allocated_fp64); + return SetWith(&::substrait::Type::set_allocated_fp64); } Status Visit(const StringType& t) { - return SetWith(&substrait::Type::set_allocated_string); + return SetWith(&::substrait::Type::set_allocated_string); } Status Visit(const BinaryType& t) { - return SetWith(&substrait::Type::set_allocated_binary); + return SetWith(&::substrait::Type::set_allocated_binary); } Status Visit(const FixedSizeBinaryType& t) { - SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width()); + SetWithThen(&::substrait::Type::set_allocated_fixed_binary) + ->set_length(t.byte_width()); return Status::OK(); } Status Visit(const Date32Type& t) { - return SetWith(&substrait::Type::set_allocated_date); + return SetWith(&::substrait::Type::set_allocated_date); } Status Visit(const Date64Type& t) { return NotImplemented(t); } @@ -274,10 +277,10 @@ struct DataTypeToProtoImpl { if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); if (t.timezone() == "") { - return SetWith(&substrait::Type::set_allocated_timestamp); + return SetWith(&::substrait::Type::set_allocated_timestamp); } if (t.timezone() == TimestampTzTimezoneString()) { - return SetWith(&substrait::Type::set_allocated_timestamp_tz); + return SetWith(&::substrait::Type::set_allocated_timestamp_tz); } return NotImplemented(t); @@ -286,14 +289,14 @@ struct DataTypeToProtoImpl { Status Visit(const Time32Type& t) { return NotImplemented(t); } Status Visit(const Time64Type& t) { if (t.unit() != 
TimeUnit::MICRO) return NotImplemented(t); - return SetWith(&substrait::Type::set_allocated_time); + return SetWith(&::substrait::Type::set_allocated_time); } Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); } Status Visit(const DayTimeIntervalType& t) { return EncodeUserDefined(t); } Status Visit(const Decimal128Type& t) { - auto dec = SetWithThen(&substrait::Type::set_allocated_decimal); + auto dec = SetWithThen(&::substrait::Type::set_allocated_decimal); dec->set_precision(t.precision()); dec->set_scale(t.scale()); return Status::OK(); @@ -304,18 +307,20 @@ struct DataTypeToProtoImpl { // FIXME assert default field name; custom ones won't roundtrip ARROW_ASSIGN_OR_RAISE( auto type, ToProto(*t.value_type(), t.value_field()->nullable(), ext_set_)); - SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release()); + SetWithThen(&::substrait::Type::set_allocated_list) + ->set_allocated_type(type.release()); return Status::OK(); } Status Visit(const StructType& t) { - auto types = SetWithThen(&substrait::Type::set_allocated_struct_)->mutable_types(); + auto types = SetWithThen(&::substrait::Type::set_allocated_struct_)->mutable_types(); types->Reserve(t.num_fields()); for (const auto& field : t.fields()) { if (field->metadata() != nullptr) { - return Status::Invalid("substrait::Type::Struct does not support field metadata"); + return Status::Invalid( + "::substrait::Type::Struct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set_)); @@ -330,7 +335,7 @@ struct DataTypeToProtoImpl { Status Visit(const MapType& t) { // FIXME assert default field names; custom ones won't roundtrip - auto map = SetWithThen(&substrait::Type::set_allocated_map); + auto map = SetWithThen(&::substrait::Type::set_allocated_map); ARROW_ASSIGN_OR_RAISE(auto key, ToProto(*t.key_type(), /*nullable=*/false, ext_set_)); map->set_allocated_key(key.release()); @@ -344,25 +349,25 @@ struct DataTypeToProtoImpl { Status Visit(const ExtensionType& t) { if (UnwrapUuid(t)) { - return SetWith(&substrait::Type::set_allocated_uuid); + return SetWith(&::substrait::Type::set_allocated_uuid); } if (auto length = UnwrapFixedChar(t)) { - SetWithThen(&substrait::Type::set_allocated_fixed_char)->set_length(*length); + SetWithThen(&::substrait::Type::set_allocated_fixed_char)->set_length(*length); return Status::OK(); } if (auto length = UnwrapVarChar(t)) { - SetWithThen(&substrait::Type::set_allocated_varchar)->set_length(*length); + SetWithThen(&::substrait::Type::set_allocated_varchar)->set_length(*length); return Status::OK(); } if (UnwrapIntervalYear(t)) { - return SetWith(&substrait::Type::set_allocated_interval_year); + return SetWith(&::substrait::Type::set_allocated_interval_year); } if (UnwrapIntervalDay(t)) { - return SetWith(&substrait::Type::set_allocated_interval_day); + return SetWith(&::substrait::Type::set_allocated_interval_day); } return NotImplemented(t); @@ -376,10 +381,10 @@ struct DataTypeToProtoImpl { Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); } template - Sub* SetWithThen(void (substrait::Type::*set_allocated_sub)(Sub*)) { + Sub* SetWithThen(void (::substrait::Type::*set_allocated_sub)(Sub*)) { auto sub = internal::make_unique(); - sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE - : substrait::Type::NULLABILITY_REQUIRED); + sub->set_nullability(nullable_ ? 
::substrait::Type::NULLABILITY_NULLABLE + : ::substrait::Type::NULLABILITY_REQUIRED); auto out = sub.get(); (type_->*set_allocated_sub)(sub.release()); @@ -387,7 +392,7 @@ struct DataTypeToProtoImpl { } template - Status SetWith(void (substrait::Type::*set_allocated_sub)(Sub*)) { + Status SetWith(void (::substrait::Type::*set_allocated_sub)(Sub*)) { return SetWithThen(set_allocated_sub), Status::OK(); } @@ -399,25 +404,25 @@ struct DataTypeToProtoImpl { } Status NotImplemented(const DataType& t) { - return Status::NotImplemented("conversion to substrait::Type from ", t.ToString()); + return Status::NotImplemented("conversion to ::substrait::Type from ", t.ToString()); } Status operator()(const DataType& type) { return VisitTypeInline(type, this); } - substrait::Type* type_; + ::substrait::Type* type_; bool nullable_; ExtensionSet* ext_set_; }; } // namespace -Result> ToProto(const DataType& type, bool nullable, - ExtensionSet* ext_set) { - auto out = internal::make_unique(); +Result> ToProto(const DataType& type, bool nullable, + ExtensionSet* ext_set) { + auto out = internal::make_unique<::substrait::Type>(); RETURN_NOT_OK((DataTypeToProtoImpl{out.get(), nullable, ext_set})(type)); return std::move(out); } -Result> FromProto(const substrait::NamedStruct& named_struct, +Result> FromProto(const ::substrait::NamedStruct& named_struct, const ExtensionSet& ext_set) { if (!named_struct.has_struct_()) { return Status::Invalid("While converting ", named_struct.DebugString(), @@ -461,25 +466,25 @@ void ToProtoGetDepthFirstNames(const FieldVector& fields, } } // namespace -Result> ToProto(const Schema& schema, - ExtensionSet* ext_set) { +Result> ToProto(const Schema& schema, + ExtensionSet* ext_set) { if (schema.metadata()) { - return Status::Invalid("substrait::NamedStruct does not support schema metadata"); + return Status::Invalid("::substrait::NamedStruct does not support schema metadata"); } - auto named_struct = internal::make_unique(); + auto named_struct = internal::make_unique<::substrait::NamedStruct>(); auto names = named_struct->mutable_names(); names->Reserve(schema.num_fields()); ToProtoGetDepthFirstNames(schema.fields(), names); - auto struct_ = internal::make_unique(); + auto struct_ = internal::make_unique<::substrait::Type::Struct>(); auto types = struct_->mutable_types(); types->Reserve(schema.num_fields()); for (const auto& field : schema.fields()) { if (field->metadata() != nullptr) { - return Status::Invalid("substrait::NamedStruct does not support field metadata"); + return Status::Invalid("::substrait::NamedStruct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set)); diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc new file mode 100644 index 00000000000..bc2aa36856e --- /dev/null +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/util.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/async_util.h" + +namespace arrow { + +namespace engine { + +namespace substrait { + +namespace { + +/// \brief A SinkNodeConsumer specialized to output ExecBatches via PushGenerator +class SubstraitSinkConsumer : public compute::SinkNodeConsumer { + public: + explicit SubstraitSinkConsumer( + arrow::PushGenerator<util::optional<compute::ExecBatch>>::Producer producer) + : producer_(std::move(producer)) {} + + Status Consume(compute::ExecBatch batch) override { + // Consume a batch of data + bool did_push = producer_.Push(batch); + if (!did_push) return Status::Invalid("Producer closed already"); + return Status::OK(); + } + + Status Init(const std::shared_ptr<Schema>& schema, + compute::BackpressureControl* backpressure_control) override { + schema_ = schema; + return Status::OK(); + } + + Future<> Finish() override { + ARROW_UNUSED(producer_.Close()); + return Future<>::MakeFinished(); + } + + std::shared_ptr<Schema> schema() { return schema_; } + + private: + arrow::PushGenerator<util::optional<compute::ExecBatch>>::Producer producer_; + std::shared_ptr<Schema> schema_; +}; + +/// \brief An executor to run a Substrait query +/// This interface is provided as a utility when creating language +/// bindings for consuming a Substrait plan. +class SubstraitExecutor { + public: + explicit SubstraitExecutor(std::shared_ptr<compute::ExecPlan> plan, + compute::ExecContext exec_context) + : plan_(std::move(plan)), exec_context_(exec_context) {} + + ~SubstraitExecutor() { ARROW_CHECK_OK(this->Close()); } + + Result<std::shared_ptr<RecordBatchReader>> Execute() { + for (const compute::Declaration& decl : declarations_) { + RETURN_NOT_OK(decl.AddToPlan(plan_.get()).status()); + } + RETURN_NOT_OK(plan_->Validate()); + RETURN_NOT_OK(plan_->StartProducing()); + auto schema = sink_consumer_->schema(); + std::shared_ptr<RecordBatchReader> sink_reader = compute::MakeGeneratorReader( + std::move(schema), std::move(generator_), exec_context_.memory_pool()); + return sink_reader; + } + + Status Close() { return plan_->finished().status(); } + + Status Init(const Buffer& substrait_buffer) { + if (substrait_buffer.size() == 0) { + return Status::Invalid("Empty substrait plan is passed."); + } + sink_consumer_ = std::make_shared<SubstraitSinkConsumer>(generator_.producer()); + std::function<std::shared_ptr<compute::SinkNodeConsumer>()> consumer_factory = [&] { + return sink_consumer_; + }; + ARROW_ASSIGN_OR_RAISE(declarations_, + engine::DeserializePlans(substrait_buffer, consumer_factory)); + return Status::OK(); + } + + private: + arrow::PushGenerator<util::optional<compute::ExecBatch>> generator_; + std::vector<compute::Declaration> declarations_; + std::shared_ptr<compute::ExecPlan> plan_; + compute::ExecContext exec_context_; + std::shared_ptr<SubstraitSinkConsumer> sink_consumer_; +}; + +} // namespace + +Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan( + const Buffer& substrait_buffer) { + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make()); + // TODO(ARROW-15732) + compute::ExecContext exec_context(arrow::default_memory_pool(), + ::arrow::internal::GetCpuThreadPool()); + SubstraitExecutor executor(std::move(plan), exec_context); + RETURN_NOT_OK(executor.Init(substrait_buffer)); + ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute()); + return sink_reader; +} + +Result<std::shared_ptr<Buffer>> SerializeJsonPlan(const std::string&
substrait_json) { + return engine::internal::SubstraitFromJSON("Plan", substrait_json); +} + +} // namespace substrait + +} // namespace engine + +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h new file mode 100644 index 00000000000..860a459da2f --- /dev/null +++ b/cpp/src/arrow/engine/substrait/util.h @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include "arrow/engine/substrait/api.h" +#include "arrow/util/iterator.h" +#include "arrow/util/optional.h" + +namespace arrow { + +namespace engine { + +namespace substrait { + +/// \brief Retrieve a RecordBatchReader from a Substrait plan. +ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan( + const Buffer& substrait_buffer); + +/// \brief Get a serialized plan from a Substrait JSON plan. +/// This is a helper method for Python tests. +ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan( + const std::string& substrait_json); + +} // namespace substrait + +} // namespace engine + +} // namespace arrow diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index d17c76e2888..3e08253f329 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -69,6 +69,7 @@ endif() if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") option(PYARROW_BUILD_CUDA "Build the PyArrow CUDA support" OFF) option(PYARROW_BUILD_FLIGHT "Build the PyArrow Flight integration" OFF) + option(PYARROW_BUILD_SUBSTRAIT "Build the PyArrow Substrait integration" OFF) option(PYARROW_BUILD_DATASET "Build the PyArrow Dataset integration" OFF) option(PYARROW_BUILD_GANDIVA "Build the PyArrow Gandiva integration" OFF) option(PYARROW_BUILD_PARQUET "Build the PyArrow Parquet integration" OFF) @@ -227,6 +228,10 @@ if(PYARROW_BUILD_FLIGHT) set(ARROW_FLIGHT TRUE) endif() +if(PYARROW_BUILD_SUBSTRAIT) + set(ARROW_SUBSTRAIT TRUE) +endif() + # Arrow find_package(ArrowPython REQUIRED) include_directories(SYSTEM ${ARROW_INCLUDE_DIR}) @@ -535,6 +540,17 @@ if(PYARROW_BUILD_FLIGHT) set(CYTHON_EXTENSIONS ${CYTHON_EXTENSIONS} _flight) endif() +# Engine +if(PYARROW_BUILD_SUBSTRAIT) + find_package(ArrowSubstrait REQUIRED) + if(PYARROW_BUNDLE_ARROW_CPP) + bundle_arrow_lib(ARROW_SUBSTRAIT_SHARED_LIB SO_VERSION ${ARROW_SO_VERSION}) + endif() + + set(SUBSTRAIT_LINK_LIBS arrow_substrait_shared) + set(CYTHON_EXTENSIONS ${CYTHON_EXTENSIONS} _substrait) +endif() + # Gandiva if(PYARROW_BUILD_GANDIVA) find_package(Gandiva REQUIRED) @@ -625,6 +641,10 @@ if(PYARROW_BUILD_FLIGHT) target_link_libraries(_flight PRIVATE ${FLIGHT_LINK_LIBS}) endif() +if(PYARROW_BUILD_SUBSTRAIT) + target_link_libraries(_substrait PRIVATE ${SUBSTRAIT_LINK_LIBS}) +endif() + if(PYARROW_BUILD_DATASET) target_link_libraries(_dataset PRIVATE ${DATASET_LINK_LIBS})
target_link_libraries(_exec_plan PRIVATE ${DATASET_LINK_LIBS}) diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx new file mode 100644 index 00000000000..7f079fb717b --- /dev/null +++ b/python/pyarrow/_substrait.pyx @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 +from cython.operator cimport dereference as deref + +from pyarrow import Buffer +from pyarrow.lib cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_substrait cimport * + + +def run_query(plan): + """ + Execute a Substrait plan and read the results as a RecordBatchReader. + + Parameters + ---------- + plan : Buffer + The serialized Substrait plan to execute. + """ + + cdef: + CResult[shared_ptr[CRecordBatchReader]] c_res_reader + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader reader + c_string c_str_plan + shared_ptr[CBuffer] c_buf_plan + + c_buf_plan = pyarrow_unwrap_buffer(plan) + with nogil: + c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan)) + + c_reader = GetResultValue(c_res_reader) + + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = c_reader + return reader + + +def _parse_json_plan(plan): + """ + Parse a JSON plan into equivalent serialized Protobuf. + + Parameters + ---------- + plan: bytes + Substrait plan in JSON. + + Returns + ------- + Buffer + A buffer containing the serialized Protobuf plan. + """ + + cdef: + CResult[shared_ptr[CBuffer]] c_res_buffer + c_string c_str_plan + shared_ptr[CBuffer] c_buf_plan + + c_str_plan = plan + c_res_buffer = SerializeJsonPlan(c_str_plan) + with nogil: + c_buf_plan = GetResultValue(c_res_buffer) + return pyarrow_wrap_buffer(c_buf_plan) diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd new file mode 100644 index 00000000000..2e1a17b06bd --- /dev/null +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# distutils: language = c++ + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + + +cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine::substrait" nogil: + CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer) + CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py new file mode 100644 index 00000000000..e3ff28f4eba --- /dev/null +++ b/python/pyarrow/substrait.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyarrow._substrait import ( # noqa + run_query, +) diff --git a/python/pyarrow/tests/conftest.py b/python/pyarrow/tests/conftest.py index 466b1647fdd..a5aae6f634e 100644 --- a/python/pyarrow/tests/conftest.py +++ b/python/pyarrow/tests/conftest.py @@ -66,6 +66,7 @@ 'plasma', 's3', 'snappy', + 'substrait', 'tensorflow', 'flight', 'slow', @@ -98,6 +99,7 @@ 's3': False, 'slow': False, 'snappy': Codec.is_available('snappy'), + 'substrait': False, 'tensorflow': False, 'zstd': Codec.is_available('zstd'), } @@ -181,6 +183,12 @@ except ImportError: pass +try: + import pyarrow.substrait # noqa + defaults['substrait'] = True +except ImportError: + pass + def pytest_addoption(parser): # Create options to selectively enable test groups diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py new file mode 100644 index 00000000000..8df35bbba44 --- /dev/null +++ b/python/pyarrow/tests/test_substrait.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import pytest + +import pyarrow as pa +from pyarrow.lib import tobytes +from pyarrow.lib import ArrowInvalid + +try: + import pyarrow.substrait as substrait +except ImportError: + substrait = None + +# Marks all of the tests in this module +# Ignore these with pytest ... 
-m 'not substrait' +pytestmark = [pytest.mark.dataset, pytest.mark.substrait] + + +@pytest.mark.skipif(sys.platform == 'win32', + reason="ARROW-16392: file based URI is" + + " not fully supported for Windows") +def test_run_serialized_query(tmpdir): + substrait_query = """ + { + "relations": [ + {"rel": { + "read": { + "base_schema": { + "struct": { + "types": [ + {"i64": {}} + ] + }, + "names": [ + "foo" + ] + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER" + } + ] + } + } + }} + ] + } + """ + # TODO: replace with ipc when the support is finalized in C++ + path = os.path.join(str(tmpdir), 'substrait_data.arrow') + table = pa.table([[1, 2, 3, 4, 5]], names=['foo']) + with pa.ipc.RecordBatchFileWriter(path, schema=table.schema) as writer: + writer.write_table(table) + + query = tobytes(substrait_query.replace("FILENAME_PLACEHOLDER", path)) + + buf = pa._substrait._parse_json_plan(query) + + reader = substrait.run_query(buf) + res_tb = reader.read_all() + + assert table.select(["foo"]) == res_tb.select(["foo"]) + + +def test_invalid_plan(): + query = """ + { + "relations": [ + ] + } + """ + buf = pa._substrait._parse_json_plan(tobytes(query)) + exec_message = "Empty substrait plan is passed." + with pytest.raises(ArrowInvalid, match=exec_message): + substrait.run_query(buf) diff --git a/python/setup.py b/python/setup.py index 1189357b234..79ec3c8447e 100755 --- a/python/setup.py +++ b/python/setup.py @@ -108,6 +108,7 @@ def run(self): 'namespace of boost (default: boost)'), ('with-cuda', None, 'build the Cuda extension'), ('with-flight', None, 'build the Flight extension'), + ('with-substrait', None, 'build the Substrait extension'), ('with-dataset', None, 'build the Dataset extension'), ('with-parquet', None, 'build the Parquet extension'), ('with-parquet-encryption', None, @@ -160,6 +161,8 @@ def initialize_options(self): os.environ.get('PYARROW_WITH_HDFS', '0')) self.with_cuda = strtobool( os.environ.get('PYARROW_WITH_CUDA', '0')) + self.with_substrait = strtobool( + os.environ.get('PYARROW_WITH_SUBSTRAIT', '0')) self.with_flight = strtobool( os.environ.get('PYARROW_WITH_FLIGHT', '0')) self.with_dataset = strtobool( @@ -214,6 +217,7 @@ def initialize_options(self): '_orc', '_plasma', '_s3fs', + '_substrait', '_hdfs', '_hdfsio', 'gandiva'] @@ -268,6 +272,7 @@ def append_cmake_bool(value, varname): cmake_options += ['-G', self.cmake_generator] append_cmake_bool(self.with_cuda, 'PYARROW_BUILD_CUDA') + append_cmake_bool(self.with_substrait, 'PYARROW_BUILD_SUBSTRAIT') append_cmake_bool(self.with_flight, 'PYARROW_BUILD_FLIGHT') append_cmake_bool(self.with_gandiva, 'PYARROW_BUILD_GANDIVA') append_cmake_bool(self.with_dataset, 'PYARROW_BUILD_DATASET') @@ -393,6 +398,8 @@ def _bundle_arrow_cpp(self, build_prefix, build_lib): move_shared_libs(build_prefix, build_lib, "arrow_python") if self.with_cuda: move_shared_libs(build_prefix, build_lib, "arrow_cuda") + if self.with_substrait: + move_shared_libs(build_prefix, build_lib, "arrow_substrait") if self.with_flight: move_shared_libs(build_prefix, build_lib, "arrow_flight") move_shared_libs(build_prefix, build_lib, @@ -438,6 +445,8 @@ def _failure_permitted(self, name): return True if name == '_flight' and not self.with_flight: return True + if name == '_substrait' and not self.with_substrait: + return True if name == '_s3fs' and not self.with_s3: return True if name == '_hdfs' and not self.with_hdfs:
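
Usage note: below is a minimal end-to-end sketch of the Python API added by this change, mirroring test_run_serialized_query above. It assumes a pyarrow build with PYARROW_WITH_SUBSTRAIT=1 (so pyarrow.substrait is importable) and a non-Windows platform (file:// URIs are not yet supported on Windows, see ARROW-16392). The file path, column name "foo", and sample data are illustrative only; pa._substrait._parse_json_plan is the internal JSON helper from this diff, not part of the public API.

    import pyarrow as pa
    import pyarrow.substrait as substrait

    # Write a small Arrow IPC file for the plan to read back.
    table = pa.table([[1, 2, 3, 4, 5]], names=["foo"])
    path = "/tmp/substrait_example.arrow"
    with pa.ipc.RecordBatchFileWriter(path, schema=table.schema) as writer:
        writer.write_table(table)

    # A Substrait plan in JSON form that scans the single int64 column "foo".
    plan_json = """
    {
      "relations": [
        {"rel": {
          "read": {
            "base_schema": {
              "struct": {"types": [{"i64": {}}]},
              "names": ["foo"]
            },
            "local_files": {
              "items": [{"uri_file": "file://%s"}]
            }
          }
        }}
      ]
    }
    """ % path

    # _parse_json_plan converts the JSON plan into a serialized protobuf Buffer;
    # run_query executes that plan and returns a RecordBatchReader.
    buf = pa._substrait._parse_json_plan(plan_json.encode())
    reader = substrait.run_query(buf)
    print(reader.read_all())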