diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 79b48461f9b..484c3e9e769 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -371,6 +371,7 @@ if(ARROW_COMPUTE) compute/exec/exec_plan.cc compute/exec/expression.cc compute/function.cc + compute/function_internal.cc compute/kernel.cc compute/registry.cc compute/kernels/aggregate_basic.cc diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index a97bf134604..682baab208d 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -397,6 +397,32 @@ TEST_F(TestArray, TestMakeArrayOfNullUnion) { } } +void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr& scalar) { + std::unique_ptr builder; + auto null_scalar = MakeNullScalar(scalar->type); + ASSERT_OK(MakeBuilder(pool, scalar->type, &builder)); + ASSERT_OK(builder->AppendScalar(*scalar)); + ASSERT_OK(builder->AppendScalar(*scalar)); + ASSERT_OK(builder->AppendScalar(*null_scalar)); + ASSERT_OK(builder->AppendScalars({scalar, null_scalar})); + ASSERT_OK(builder->AppendScalar(*scalar, /*n_repeats=*/2)); + ASSERT_OK(builder->AppendScalar(*null_scalar, /*n_repeats=*/2)); + + std::shared_ptr out; + FinishAndCheckPadding(builder.get(), &out); + ASSERT_OK(out->ValidateFull()); + ASSERT_EQ(out->length(), 9); + ASSERT_EQ(out->null_count(), 4); + for (const auto index : {0, 1, 3, 5, 6}) { + ASSERT_FALSE(out->IsNull(index)); + ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index)); + AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true); + } + for (const auto index : {2, 4, 7, 8}) { + ASSERT_TRUE(out->IsNull(index)); + } +} + TEST_F(TestArray, TestMakeArrayFromScalar) { ASSERT_OK_AND_ASSIGN(auto null_array, MakeArrayFromScalar(NullScalar(), 5)); ASSERT_OK(null_array->ValidateFull()); @@ -447,6 +473,10 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { ASSERT_EQ(array->null_count(), 0); } } + + for (auto scalar : scalars) { + AssertAppendScalar(pool_, 
scalar); + } } TEST_F(TestArray, TestMakeArrayFromDictionaryScalar) { @@ -481,6 +511,8 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) { ASSERT_OK_AND_ASSIGN(auto item, array->GetScalar(i)); ASSERT_TRUE(item->Equals(scalar)); } + + AssertAppendScalar(pool_, std::make_shared(scalar)); } TEST_F(TestArray, ValidateBuffersPrimitive) { diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index b92cc285894..c892e3d664b 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -24,8 +24,11 @@ #include "arrow/array/data.h" #include "arrow/array/util.h" #include "arrow/buffer.h" +#include "arrow/builder.h" +#include "arrow/scalar.h" #include "arrow/status.h" #include "arrow/util/logging.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -92,6 +95,162 @@ Status ArrayBuilder::Advance(int64_t elements) { return null_bitmap_builder_.Advance(elements); } +namespace { +struct AppendScalarImpl { + template + enable_if_t::value || is_decimal_type::value || + is_fixed_size_binary_type::value, + Status> + Visit(const T&) { + auto builder = internal::checked_cast::BuilderType*>(builder_); + RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_))); + + for (int64_t i = 0; i < n_repeats_; i++) { + for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; + raw++) { + auto scalar = + internal::checked_cast::ScalarType*>(raw->get()); + if (scalar->is_valid) { + builder->UnsafeAppend(scalar->value); + } else { + builder->UnsafeAppendNull(); + } + } + } + return Status::OK(); + } + + template + enable_if_base_binary Visit(const T&) { + int64_t data_size = 0; + for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; + raw++) { + auto scalar = + internal::checked_cast::ScalarType*>(raw->get()); + if (scalar->is_valid) { + data_size += scalar->value->size(); + } + } + + auto builder = internal::checked_cast::BuilderType*>(builder_); + 
RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_))); + RETURN_NOT_OK(builder->ReserveData(n_repeats_ * data_size)); + + for (int64_t i = 0; i < n_repeats_; i++) { + for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; + raw++) { + auto scalar = + internal::checked_cast::ScalarType*>(raw->get()); + if (scalar->is_valid) { + builder->UnsafeAppend(util::string_view{*scalar->value}); + } else { + builder->UnsafeAppendNull(); + } + } + } + return Status::OK(); + } + + template + enable_if_list_like Visit(const T&) { + auto builder = internal::checked_cast::BuilderType*>(builder_); + int64_t num_children = 0; + for (const std::shared_ptr* scalar = scalars_begin_; scalar != scalars_end_; + scalar++) { + if (!(*scalar)->is_valid) continue; + num_children += + internal::checked_cast(**scalar).value->length(); + } + RETURN_NOT_OK(builder->value_builder()->Reserve(num_children * n_repeats_)); + + for (int64_t i = 0; i < n_repeats_; i++) { + for (const std::shared_ptr* scalar = scalars_begin_; scalar != scalars_end_; + scalar++) { + if ((*scalar)->is_valid) { + RETURN_NOT_OK(builder->Append()); + const Array& list = + *internal::checked_cast(**scalar).value; + for (int64_t i = 0; i < list.length(); i++) { + ARROW_ASSIGN_OR_RAISE(auto scalar, list.GetScalar(i)); + RETURN_NOT_OK(builder->value_builder()->AppendScalar(*scalar)); + } + } else { + RETURN_NOT_OK(builder_->AppendNull()); + } + } + } + return Status::OK(); + } + + Status Visit(const StructType& type) { + auto* builder = internal::checked_cast(builder_); + auto count = n_repeats_ * (scalars_end_ - scalars_begin_); + RETURN_NOT_OK(builder->Reserve(count)); + for (int field_index = 0; field_index < type.num_fields(); ++field_index) { + RETURN_NOT_OK(builder->field_builder(field_index)->Reserve(count)); + } + for (int64_t i = 0; i < n_repeats_; i++) { + for (const std::shared_ptr* s = scalars_begin_; s != scalars_end_; s++) { + const auto& scalar = internal::checked_cast(**s); + 
for (int field_index = 0; field_index < type.num_fields(); ++field_index) { + if (!scalar.is_valid || !scalar.value[field_index]) { + RETURN_NOT_OK(builder->field_builder(field_index)->AppendNull()); + } else { + RETURN_NOT_OK(builder->field_builder(field_index) + ->AppendScalar(*scalar.value[field_index])); + } + } + RETURN_NOT_OK(builder->Append(scalar.is_valid)); + } + } + return Status::OK(); + } + + Status Visit(const DataType& type) { + return Status::NotImplemented("AppendScalar for type ", type); + } + + Status Convert() { return VisitTypeInline(*(*scalars_begin_)->type, this); } + + const std::shared_ptr* scalars_begin_; + const std::shared_ptr* scalars_end_; + int64_t n_repeats_; + ArrayBuilder* builder_; +}; +} // namespace + +Status ArrayBuilder::AppendScalar(const Scalar& scalar) { + if (!scalar.type->Equals(type())) { + return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), + " to builder for type ", type()->ToString()); + } + std::shared_ptr shared{const_cast(&scalar), [](Scalar*) {}}; + return AppendScalarImpl{&shared, &shared + 1, /*n_repeats=*/1, this}.Convert(); +} + +Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) { + if (!scalar.type->Equals(type())) { + return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), + " to builder for type ", type()->ToString()); + } + std::shared_ptr shared{const_cast(&scalar), [](Scalar*) {}}; + return AppendScalarImpl{&shared, &shared + 1, n_repeats, this}.Convert(); +} + +Status ArrayBuilder::AppendScalars(const ScalarVector& scalars) { + if (scalars.empty()) return Status::OK(); + const auto ty = type(); + for (const auto& scalar : scalars) { + if (!scalar->type->Equals(ty)) { + return Status::Invalid("Cannot append scalar of type ", scalar->type->ToString(), + " to builder for type ", type()->ToString()); + } + } + return AppendScalarImpl{scalars.data(), scalars.data() + scalars.size(), + /*n_repeats=*/1, this} + .Convert(); +} + 
Status ArrayBuilder::Finish(std::shared_ptr* out) { std::shared_ptr internal_data; RETURN_NOT_OK(FinishInternal(&internal_data)); diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h index 15c726241b5..8e60c306796 100644 --- a/cpp/src/arrow/array/builder_base.h +++ b/cpp/src/arrow/array/builder_base.h @@ -116,6 +116,11 @@ class ARROW_EXPORT ArrayBuilder { /// This method is useful when appending null values to a parent nested type. virtual Status AppendEmptyValues(int64_t length) = 0; + /// \brief Append a value from a scalar + Status AppendScalar(const Scalar& scalar); + Status AppendScalar(const Scalar& scalar, int64_t n_repeats); + Status AppendScalars(const ScalarVector& scalars); + /// For cases where raw data was memcpy'd into the internal buffers, allows us /// to advance the length of the builder. It is your responsibility to use /// this function responsibly. diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index c1c664a1249..7653eeca5c4 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -467,6 +467,14 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { return Status::OK(); } + Status Append(const Buffer& s) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(util::string_view(s)); + return Status::OK(); + } + + Status Append(const std::shared_ptr& s) { return Append(*s); } + template Status Append(const std::array& value) { ARROW_RETURN_NOT_OK(Reserve(1)); @@ -502,6 +510,10 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { UnsafeAppend(reinterpret_cast(value.data())); } + void UnsafeAppend(const Buffer& s) { UnsafeAppend(util::string_view(s)); } + + void UnsafeAppend(const std::shared_ptr& s) { UnsafeAppend(*s); } + void UnsafeAppendNull() { UnsafeAppendToBitmap(false); byte_builder_.UnsafeAppend(/*num_copies=*/byte_width_, 0); diff --git a/cpp/src/arrow/array/builder_dict.h 
b/cpp/src/arrow/array/builder_dict.h index 40d6ce1ba9a..455cb3df7b1 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -29,6 +29,7 @@ #include "arrow/array/builder_primitive.h" // IWYU pragma: export #include "arrow/array/data.h" #include "arrow/array/util.h" +#include "arrow/scalar.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/type_traits.h" diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc index efff4ac67df..be05c3c11d0 100644 --- a/cpp/src/arrow/compute/api_aggregate.cc +++ b/cpp/src/arrow/compute/api_aggregate.cc @@ -18,10 +18,120 @@ #include "arrow/compute/api_aggregate.h" #include "arrow/compute/exec.h" +#include "arrow/compute/function_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/util_internal.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" namespace arrow { + +namespace internal { +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "QuantileOptions::Interpolation"; } + static std::string value_name(compute::QuantileOptions::Interpolation value) { + switch (value) { + case compute::QuantileOptions::LINEAR: + return "LINEAR"; + case compute::QuantileOptions::LOWER: + return "LOWER"; + case compute::QuantileOptions::HIGHER: + return "HIGHER"; + case compute::QuantileOptions::NEAREST: + return "NEAREST"; + case compute::QuantileOptions::MIDPOINT: + return "MIDPOINT"; + } + return ""; + } +}; +} // namespace internal + namespace compute { +// ---------------------------------------------------------------------- +// Function options + +using ::arrow::internal::checked_cast; + +namespace internal { +namespace { +using ::arrow::internal::DataMember; +static auto kScalarAggregateOptionsType = GetFunctionOptionsType( + DataMember("skip_nulls", &ScalarAggregateOptions::skip_nulls), + DataMember("min_count", &ScalarAggregateOptions::min_count)); +static auto 
kModeOptionsType = + GetFunctionOptionsType(DataMember("n", &ModeOptions::n)); +static auto kVarianceOptionsType = + GetFunctionOptionsType(DataMember("ddof", &VarianceOptions::ddof)); +static auto kQuantileOptionsType = GetFunctionOptionsType( + DataMember("q", &QuantileOptions::q), + DataMember("interpolation", &QuantileOptions::interpolation)); +static auto kTDigestOptionsType = GetFunctionOptionsType( + DataMember("q", &TDigestOptions::q), DataMember("delta", &TDigestOptions::delta), + DataMember("buffer_size", &TDigestOptions::buffer_size)); +static auto kIndexOptionsType = + GetFunctionOptionsType(DataMember("value", &IndexOptions::value)); +} // namespace +} // namespace internal + +ScalarAggregateOptions::ScalarAggregateOptions(bool skip_nulls, uint32_t min_count) + : FunctionOptions(internal::kScalarAggregateOptionsType), + skip_nulls(skip_nulls), + min_count(min_count) {} +constexpr char ScalarAggregateOptions::kTypeName[]; + +ModeOptions::ModeOptions(int64_t n) : FunctionOptions(internal::kModeOptionsType), n(n) {} +constexpr char ModeOptions::kTypeName[]; + +VarianceOptions::VarianceOptions(int ddof) + : FunctionOptions(internal::kVarianceOptionsType), ddof(ddof) {} +constexpr char VarianceOptions::kTypeName[]; + +QuantileOptions::QuantileOptions(double q, enum Interpolation interpolation) + : FunctionOptions(internal::kQuantileOptionsType), + q{q}, + interpolation{interpolation} {} +QuantileOptions::QuantileOptions(std::vector q, enum Interpolation interpolation) + : FunctionOptions(internal::kQuantileOptionsType), + q{std::move(q)}, + interpolation{interpolation} {} +constexpr char QuantileOptions::kTypeName[]; + +TDigestOptions::TDigestOptions(double q, uint32_t delta, uint32_t buffer_size) + : FunctionOptions(internal::kTDigestOptionsType), + q{q}, + delta{delta}, + buffer_size{buffer_size} {} +TDigestOptions::TDigestOptions(std::vector q, uint32_t delta, + uint32_t buffer_size) + : FunctionOptions(internal::kTDigestOptionsType), + q{std::move(q)}, 
+ delta{delta}, + buffer_size{buffer_size} {} +constexpr char TDigestOptions::kTypeName[]; + +IndexOptions::IndexOptions(std::shared_ptr value) + : FunctionOptions(internal::kIndexOptionsType), value{std::move(value)} {} +IndexOptions::IndexOptions() : IndexOptions(std::make_shared()) {} +constexpr char IndexOptions::kTypeName[]; + +namespace internal { +void RegisterAggregateOptions(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunctionOptionsType(kScalarAggregateOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kModeOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kVarianceOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kQuantileOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kTDigestOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kIndexOptionsType)); +} +} // namespace internal + // ---------------------------------------------------------------------- // Scalar aggregates diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 121896f1c97..9be0b406aa4 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -43,10 +43,10 @@ class ExecContext; /// \brief Control general scalar aggregate kernel behavior /// /// By default, null values are ignored -struct ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { - explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1) - : skip_nulls(skip_nulls), min_count(min_count) {} - +class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { + public: + explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1); + constexpr static char const kTypeName[] = "scalar_aggregate"; static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; } bool skip_nulls; @@ -57,9 +57,10 @@ struct ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { /// /// Returns top-n common values and counts. 
/// By default, returns the most common value and count. -struct ARROW_EXPORT ModeOptions : public FunctionOptions { - explicit ModeOptions(int64_t n = 1) : n(n) {} - +class ARROW_EXPORT ModeOptions : public FunctionOptions { + public: + explicit ModeOptions(int64_t n = 1); + constexpr static char const kTypeName[] = "mode"; static ModeOptions Defaults() { return ModeOptions{}; } int64_t n = 1; @@ -69,9 +70,10 @@ struct ARROW_EXPORT ModeOptions : public FunctionOptions { /// /// The divisor used in calculations is N - ddof, where N is the number of elements. /// By default, ddof is zero, and population variance or stddev is returned. -struct ARROW_EXPORT VarianceOptions : public FunctionOptions { - explicit VarianceOptions(int ddof = 0) : ddof(ddof) {} - +class ARROW_EXPORT VarianceOptions : public FunctionOptions { + public: + explicit VarianceOptions(int ddof = 0); + constexpr static char const kTypeName[] = "variance"; static VarianceOptions Defaults() { return VarianceOptions{}; } int ddof = 0; @@ -80,7 +82,8 @@ struct ARROW_EXPORT VarianceOptions : public FunctionOptions { /// \brief Control Quantile kernel behavior /// /// By default, returns the median value. 
-struct ARROW_EXPORT QuantileOptions : public FunctionOptions { +class ARROW_EXPORT QuantileOptions : public FunctionOptions { + public: /// Interpolation method to use when quantile lies between two data points enum Interpolation { LINEAR = 0, @@ -90,13 +93,12 @@ struct ARROW_EXPORT QuantileOptions : public FunctionOptions { MIDPOINT, }; - explicit QuantileOptions(double q = 0.5, enum Interpolation interpolation = LINEAR) - : q{q}, interpolation{interpolation} {} + explicit QuantileOptions(double q = 0.5, enum Interpolation interpolation = LINEAR); explicit QuantileOptions(std::vector q, - enum Interpolation interpolation = LINEAR) - : q{std::move(q)}, interpolation{interpolation} {} + enum Interpolation interpolation = LINEAR); + constexpr static char const kTypeName[] = "quantile"; static QuantileOptions Defaults() { return QuantileOptions{}; } /// quantile must be between 0 and 1 inclusive @@ -107,15 +109,13 @@ struct ARROW_EXPORT QuantileOptions : public FunctionOptions { /// \brief Control TDigest approximate quantile kernel behavior /// /// By default, returns the median value. 
-struct ARROW_EXPORT TDigestOptions : public FunctionOptions { +class ARROW_EXPORT TDigestOptions : public FunctionOptions { + public: explicit TDigestOptions(double q = 0.5, uint32_t delta = 100, - uint32_t buffer_size = 500) - : q{q}, delta{delta}, buffer_size{buffer_size} {} - + uint32_t buffer_size = 500); explicit TDigestOptions(std::vector q, uint32_t delta = 100, - uint32_t buffer_size = 500) - : q{std::move(q)}, delta{delta}, buffer_size{buffer_size} {} - + uint32_t buffer_size = 500); + constexpr static char const kTypeName[] = "t_digest"; static TDigestOptions Defaults() { return TDigestOptions{}; } /// quantile must be between 0 and 1 inclusive @@ -127,8 +127,12 @@ struct ARROW_EXPORT TDigestOptions : public FunctionOptions { }; /// \brief Control Index kernel behavior -struct ARROW_EXPORT IndexOptions : public FunctionOptions { - explicit IndexOptions(std::shared_ptr value) : value{std::move(value)} {} +class ARROW_EXPORT IndexOptions : public FunctionOptions { + public: + explicit IndexOptions(std::shared_ptr value); + // Default constructor for serialization + IndexOptions(); + constexpr static char const kTypeName[] = "index"; std::shared_ptr value; }; diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index db1cac290cf..5c8b91cf08a 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -21,13 +21,277 @@ #include #include +#include "arrow/array/array_base.h" #include "arrow/compute/exec.h" +#include "arrow/compute/function_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/util_internal.h" #include "arrow/status.h" #include "arrow/type.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" namespace arrow { + +namespace internal { +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "JoinOptions::NullHandlingBehavior"; } + static std::string value_name(compute::JoinOptions::NullHandlingBehavior 
value) { + switch (value) { + case compute::JoinOptions::NullHandlingBehavior::EMIT_NULL: + return "EMIT_NULL"; + case compute::JoinOptions::NullHandlingBehavior::SKIP: + return "SKIP"; + case compute::JoinOptions::NullHandlingBehavior::REPLACE: + return "REPLACE"; + } + return ""; + } +}; +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "TimeUnit::type"; } + static std::string value_name(TimeUnit::type value) { + switch (value) { + case TimeUnit::type::SECOND: + return "SECOND"; + case TimeUnit::type::MILLI: + return "MILLI"; + case TimeUnit::type::MICRO: + return "MICRO"; + case TimeUnit::type::NANO: + return "NANO"; + } + return ""; + } +}; +template <> +struct EnumTraits + : BasicEnumTraits< + compute::CompareOperator, compute::CompareOperator::EQUAL, + compute::CompareOperator::NOT_EQUAL, compute::CompareOperator::GREATER, + compute::CompareOperator::GREATER_EQUAL, compute::CompareOperator::LESS, + compute::CompareOperator::LESS_EQUAL> { + static std::string name() { return "compute::CompareOperator"; } + static std::string value_name(compute::CompareOperator value) { + switch (value) { + case compute::CompareOperator::EQUAL: + return "EQUAL"; + case compute::CompareOperator::NOT_EQUAL: + return "NOT_EQUAL"; + case compute::CompareOperator::GREATER: + return "GREATER"; + case compute::CompareOperator::GREATER_EQUAL: + return "GREATER_EQUAL"; + case compute::CompareOperator::LESS: + return "LESS"; + case compute::CompareOperator::LESS_EQUAL: + return "LESS_EQUAL"; + } + return ""; + } +}; +} // namespace internal + namespace compute { +// ---------------------------------------------------------------------- +// Function options + +using ::arrow::internal::checked_cast; + +namespace internal { +namespace { +using ::arrow::internal::DataMember; +static auto kElementWiseAggregateOptionsType = + GetFunctionOptionsType( + DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls)); +static auto kJoinOptionsType = 
GetFunctionOptionsType( + DataMember("null_handling", &JoinOptions::null_handling), + DataMember("null_replacement", &JoinOptions::null_replacement)); +static auto kMatchSubstringOptionsType = GetFunctionOptionsType( + DataMember("pattern", &MatchSubstringOptions::pattern), + DataMember("ignore_case", &MatchSubstringOptions::ignore_case)); +static auto kSplitOptionsType = GetFunctionOptionsType( + DataMember("max_splits", &SplitOptions::max_splits), + DataMember("reverse", &SplitOptions::reverse)); +static auto kSplitPatternOptionsType = GetFunctionOptionsType( + DataMember("pattern", &SplitPatternOptions::pattern), + DataMember("max_splits", &SplitPatternOptions::max_splits), + DataMember("reverse", &SplitPatternOptions::reverse)); +static auto kReplaceSliceOptionsType = GetFunctionOptionsType( + DataMember("start", &ReplaceSliceOptions::start), + DataMember("stop", &ReplaceSliceOptions::stop), + DataMember("replacement", &ReplaceSliceOptions::replacement)); +static auto kReplaceSubstringOptionsType = + GetFunctionOptionsType( + DataMember("pattern", &ReplaceSubstringOptions::pattern), + DataMember("replacement", &ReplaceSubstringOptions::replacement), + DataMember("max_replacements", &ReplaceSubstringOptions::max_replacements)); +static auto kExtractRegexOptionsType = GetFunctionOptionsType( + DataMember("pattern", &ExtractRegexOptions::pattern)); +static auto kSetLookupOptionsType = GetFunctionOptionsType( + DataMember("value_set", &SetLookupOptions::value_set), + DataMember("skip_nulls", &SetLookupOptions::skip_nulls)); +static auto kStrptimeOptionsType = GetFunctionOptionsType( + DataMember("format", &StrptimeOptions::format), + DataMember("unit", &StrptimeOptions::unit)); +static auto kPadOptionsType = GetFunctionOptionsType( + DataMember("width", &PadOptions::width), DataMember("padding", &PadOptions::padding)); +static auto kTrimOptionsType = GetFunctionOptionsType( + DataMember("characters", &TrimOptions::characters)); +static auto kSliceOptionsType = 
GetFunctionOptionsType( + DataMember("start", &SliceOptions::start), DataMember("stop", &SliceOptions::stop), + DataMember("step", &SliceOptions::step)); +static auto kCompareOptionsType = + GetFunctionOptionsType(DataMember("op", &CompareOptions::op)); +static auto kProjectOptionsType = GetFunctionOptionsType( + DataMember("field_names", &ProjectOptions::field_names), + DataMember("field_nullability", &ProjectOptions::field_nullability), + DataMember("field_metadata", &ProjectOptions::field_metadata)); +} // namespace +} // namespace internal + +ElementWiseAggregateOptions::ElementWiseAggregateOptions(bool skip_nulls) + : FunctionOptions(internal::kElementWiseAggregateOptionsType), + skip_nulls(skip_nulls) {} +constexpr char ElementWiseAggregateOptions::kTypeName[]; + +JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement) + : FunctionOptions(internal::kJoinOptionsType), + null_handling(null_handling), + null_replacement(std::move(null_replacement)) {} +constexpr char JoinOptions::kTypeName[]; + +MatchSubstringOptions::MatchSubstringOptions(std::string pattern, bool ignore_case) + : FunctionOptions(internal::kMatchSubstringOptionsType), + pattern(std::move(pattern)), + ignore_case(ignore_case) {} +MatchSubstringOptions::MatchSubstringOptions() : MatchSubstringOptions("", false) {} +constexpr char MatchSubstringOptions::kTypeName[]; + +SplitOptions::SplitOptions(int64_t max_splits, bool reverse) + : FunctionOptions(internal::kSplitOptionsType), + max_splits(max_splits), + reverse(reverse) {} +constexpr char SplitOptions::kTypeName[]; + +SplitPatternOptions::SplitPatternOptions(std::string pattern, int64_t max_splits, + bool reverse) + : FunctionOptions(internal::kSplitPatternOptionsType), + pattern(std::move(pattern)), + max_splits(max_splits), + reverse(reverse) {} +SplitPatternOptions::SplitPatternOptions() : SplitPatternOptions("", -1, false) {} +constexpr char SplitPatternOptions::kTypeName[]; + 
+ReplaceSliceOptions::ReplaceSliceOptions(int64_t start, int64_t stop, + std::string replacement) + : FunctionOptions(internal::kReplaceSliceOptionsType), + start(start), + stop(stop), + replacement(std::move(replacement)) {} +ReplaceSliceOptions::ReplaceSliceOptions() : ReplaceSliceOptions(0, 0, "") {} +constexpr char ReplaceSliceOptions::kTypeName[]; + +ReplaceSubstringOptions::ReplaceSubstringOptions(std::string pattern, + std::string replacement, + int64_t max_replacements) + : FunctionOptions(internal::kReplaceSubstringOptionsType), + pattern(std::move(pattern)), + replacement(std::move(replacement)), + max_replacements(max_replacements) {} +ReplaceSubstringOptions::ReplaceSubstringOptions() + : ReplaceSubstringOptions("", "", -1) {} +constexpr char ReplaceSubstringOptions::kTypeName[]; + +ExtractRegexOptions::ExtractRegexOptions(std::string pattern) + : FunctionOptions(internal::kExtractRegexOptionsType), pattern(std::move(pattern)) {} +ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {} +constexpr char ExtractRegexOptions::kTypeName[]; + +SetLookupOptions::SetLookupOptions(Datum value_set, bool skip_nulls) + : FunctionOptions(internal::kSetLookupOptionsType), + value_set(std::move(value_set)), + skip_nulls(skip_nulls) {} +SetLookupOptions::SetLookupOptions() : SetLookupOptions({}, false) {} +constexpr char SetLookupOptions::kTypeName[]; + +StrptimeOptions::StrptimeOptions(std::string format, TimeUnit::type unit) + : FunctionOptions(internal::kStrptimeOptionsType), + format(std::move(format)), + unit(unit) {} +StrptimeOptions::StrptimeOptions() : StrptimeOptions("", TimeUnit::SECOND) {} +constexpr char StrptimeOptions::kTypeName[]; + +PadOptions::PadOptions(int64_t width, std::string padding) + : FunctionOptions(internal::kPadOptionsType), + width(width), + padding(std::move(padding)) {} +PadOptions::PadOptions() : PadOptions(0, " ") {} +constexpr char PadOptions::kTypeName[]; + +TrimOptions::TrimOptions(std::string characters) + : 
FunctionOptions(internal::kTrimOptionsType), characters(std::move(characters)) {} +TrimOptions::TrimOptions() : TrimOptions("") {} +constexpr char TrimOptions::kTypeName[]; + +SliceOptions::SliceOptions(int64_t start, int64_t stop, int64_t step) + : FunctionOptions(internal::kSliceOptionsType), + start(start), + stop(stop), + step(step) {} +SliceOptions::SliceOptions() : SliceOptions(0, 0, 1) {} +constexpr char SliceOptions::kTypeName[]; + +CompareOptions::CompareOptions(CompareOperator op) + : FunctionOptions(internal::kCompareOptionsType), op(op) {} +CompareOptions::CompareOptions() : CompareOptions(CompareOperator::EQUAL) {} +constexpr char CompareOptions::kTypeName[]; + +ProjectOptions::ProjectOptions(std::vector n, std::vector r, + std::vector> m) + : FunctionOptions(internal::kProjectOptionsType), + field_names(std::move(n)), + field_nullability(std::move(r)), + field_metadata(std::move(m)) {} + +ProjectOptions::ProjectOptions(std::vector n) + : FunctionOptions(internal::kProjectOptionsType), + field_names(std::move(n)), + field_nullability(field_names.size(), true), + field_metadata(field_names.size(), NULLPTR) {} + +ProjectOptions::ProjectOptions() : ProjectOptions(std::vector()) {} +constexpr char ProjectOptions::kTypeName[]; + +namespace internal { +void RegisterScalarOptions(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSplitOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSplitPatternOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSliceOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSubstringOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSetLookupOptionsType)); + 
DCHECK_OK(registry->AddFunctionOptionsType(kStrptimeOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kPadOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kTrimOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSliceOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kCompareOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kProjectOptionsType)); +} +} // namespace internal + #define SCALAR_EAGER_UNARY(NAME, REGISTRY_NAME) \ Result NAME(const Datum& value, ExecContext* ctx) { \ return CallFunction(REGISTRY_NAME, {value}, ctx); \ @@ -153,7 +417,7 @@ Result Compare(const Datum& left, const Datum& right, CompareOptions opti func_name = "less_equal"; break; } - return CallFunction(func_name, {left, right}, &options, ctx); + return CallFunction(func_name, {left, right}, nullptr, ctx); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 5c83dcb5c85..b7a514b7c52 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -37,19 +37,25 @@ namespace compute { /// /// @{ -struct ArithmeticOptions : public FunctionOptions { - ArithmeticOptions() : check_overflow(false) {} +struct ARROW_EXPORT ArithmeticOptions { + public: + explicit ArithmeticOptions(bool check_overflow = false) + : check_overflow(check_overflow) {} bool check_overflow; }; -struct ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { - explicit ElementWiseAggregateOptions(bool skip_nulls = true) : skip_nulls(skip_nulls) {} +class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { + public: + explicit ElementWiseAggregateOptions(bool skip_nulls = true); + constexpr static char const kTypeName[] = "element_wise_aggregate"; static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; } + bool skip_nulls; }; /// Options for var_args_join. 
-struct ARROW_EXPORT JoinOptions : public FunctionOptions { +class ARROW_EXPORT JoinOptions : public FunctionOptions { + public: /// How to handle null values. (A null separator always results in a null output.) enum NullHandlingBehavior { /// A null in any input results in a null in the output. @@ -60,16 +66,18 @@ struct ARROW_EXPORT JoinOptions : public FunctionOptions { REPLACE, }; explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL, - std::string null_replacement = "") - : null_handling(null_handling), null_replacement(std::move(null_replacement)) {} + std::string null_replacement = ""); + constexpr static char const kTypeName[] = "join"; static JoinOptions Defaults() { return JoinOptions(); } NullHandlingBehavior null_handling; std::string null_replacement; }; -struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { - explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false) - : pattern(std::move(pattern)), ignore_case(ignore_case) {} +class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { + public: + explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false); + MatchSubstringOptions(); + constexpr static char const kTypeName[] = "match_substring"; /// The exact substring (or regex, depending on kernel) to look for inside input values. 
std::string pattern; @@ -77,9 +85,10 @@ struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { bool ignore_case = false; }; -struct ARROW_EXPORT SplitOptions : public FunctionOptions { - explicit SplitOptions(int64_t max_splits = -1, bool reverse = false) - : max_splits(max_splits), reverse(reverse) {} +class ARROW_EXPORT SplitOptions : public FunctionOptions { + public: + explicit SplitOptions(int64_t max_splits = -1, bool reverse = false); + constexpr static char const kTypeName[] = "split"; /// Maximum number of splits allowed, or unlimited when -1 int64_t max_splits; @@ -87,18 +96,26 @@ struct ARROW_EXPORT SplitOptions : public FunctionOptions { bool reverse; }; -struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { +class ARROW_EXPORT SplitPatternOptions : public FunctionOptions { + public: explicit SplitPatternOptions(std::string pattern, int64_t max_splits = -1, - bool reverse = false) - : SplitOptions(max_splits, reverse), pattern(std::move(pattern)) {} + bool reverse = false); + SplitPatternOptions(); + constexpr static char const kTypeName[] = "split_pattern"; - /// The exact substring to look for inside input values. + /// The exact substring to split on. 
std::string pattern; + /// Maximum number of splits allowed, or unlimited when -1 + int64_t max_splits; + /// Start splitting from the end of the string (only relevant when max_splits != -1) + bool reverse; }; -struct ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { - explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement) - : start(start), stop(stop), replacement(std::move(replacement)) {} +class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { + public: + explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement); + ReplaceSliceOptions(); + constexpr static char const kTypeName[] = "replace_slice"; /// Index to start slicing at int64_t start; @@ -108,12 +125,12 @@ struct ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { std::string replacement; }; -struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { +class ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { + public: explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, - int64_t max_replacements = -1) - : pattern(std::move(pattern)), - replacement(std::move(replacement)), - max_replacements(max_replacements) {} + int64_t max_replacements = -1); + ReplaceSubstringOptions(); + constexpr static char const kTypeName[] = "replace_substring"; /// Pattern to match, literal, or regular expression depending on which kernel is used std::string pattern; @@ -123,17 +140,22 @@ struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { int64_t max_replacements; }; -struct ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { - explicit ExtractRegexOptions(std::string pattern) : pattern(std::move(pattern)) {} +class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { + public: + explicit ExtractRegexOptions(std::string pattern); + ExtractRegexOptions(); + constexpr static char const kTypeName[] = "extract_regex"; /// Regular expression with named capture fields 
std::string pattern; }; /// Options for IsIn and IndexIn functions -struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { - explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) - : value_set(std::move(value_set)), skip_nulls(skip_nulls) {} +class ARROW_EXPORT SetLookupOptions : public FunctionOptions { + public: + explicit SetLookupOptions(Datum value_set, bool skip_nulls = false); + SetLookupOptions(); + constexpr static char const kTypeName[] = "set_lookup"; /// The set of values to look up input values into. Datum value_set; @@ -146,17 +168,21 @@ struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { bool skip_nulls; }; -struct ARROW_EXPORT StrptimeOptions : public FunctionOptions { - explicit StrptimeOptions(std::string format, TimeUnit::type unit) - : format(std::move(format)), unit(unit) {} +class ARROW_EXPORT StrptimeOptions : public FunctionOptions { + public: + explicit StrptimeOptions(std::string format, TimeUnit::type unit); + StrptimeOptions(); + constexpr static char const kTypeName[] = "strptime"; std::string format; TimeUnit::type unit; }; -struct ARROW_EXPORT PadOptions : public FunctionOptions { - explicit PadOptions(int64_t width, std::string padding = " ") - : width(width), padding(std::move(padding)) {} +class ARROW_EXPORT PadOptions : public FunctionOptions { + public: + explicit PadOptions(int64_t width, std::string padding = " "); + PadOptions(); + constexpr static char const kTypeName[] = "pad"; /// The desired string length. 
int64_t width; @@ -164,18 +190,22 @@ struct ARROW_EXPORT PadOptions : public FunctionOptions { std::string padding; }; -struct ARROW_EXPORT TrimOptions : public FunctionOptions { - explicit TrimOptions(std::string characters) : characters(std::move(characters)) {} +class ARROW_EXPORT TrimOptions : public FunctionOptions { + public: + explicit TrimOptions(std::string characters); + TrimOptions(); + constexpr static char const kTypeName[] = "trim"; /// The individual characters that can be trimmed from the string. std::string characters; }; -struct ARROW_EXPORT SliceOptions : public FunctionOptions { +class ARROW_EXPORT SliceOptions : public FunctionOptions { + public: explicit SliceOptions(int64_t start, int64_t stop = std::numeric_limits::max(), - int64_t step = 1) - : start(start), stop(stop), step(step) {} - + int64_t step = 1); + SliceOptions(); + constexpr static char const kTypeName[] = "slice"; int64_t start, stop, step; }; @@ -188,23 +218,21 @@ enum CompareOperator : int8_t { LESS_EQUAL, }; -struct CompareOptions : public FunctionOptions { - explicit CompareOptions(CompareOperator op) : op(op) {} - +class ARROW_EXPORT CompareOptions : public FunctionOptions { + public: + explicit CompareOptions(CompareOperator op); + CompareOptions(); + constexpr static char const kTypeName[] = "compare"; enum CompareOperator op; }; -struct ARROW_EXPORT ProjectOptions : public FunctionOptions { +class ARROW_EXPORT ProjectOptions : public FunctionOptions { + public: ProjectOptions(std::vector n, std::vector r, - std::vector> m) - : field_names(std::move(n)), - field_nullability(std::move(r)), - field_metadata(std::move(m)) {} - - explicit ProjectOptions(std::vector n) - : field_names(std::move(n)), - field_nullability(field_names.size(), true), - field_metadata(field_names.size(), NULLPTR) {} + std::vector> m); + explicit ProjectOptions(std::vector n); + ProjectOptions(); + constexpr static char const kTypeName[] = "project"; /// Names for wrapped columns std::vector 
field_names; @@ -348,8 +376,8 @@ Result MinElementWise( /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result Compare(const Datum& left, const Datum& right, - struct CompareOptions options, ExecContext* ctx = NULLPTR); +Result Compare(const Datum& left, const Datum& right, CompareOptions options, + ExecContext* ctx = NULLPTR); /// \brief Invert the values of a boolean datum /// \param[in] value datum to invert diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 0082d48112d..9c1ef8533b4 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -18,23 +18,139 @@ #include "arrow/compute/api_vector.h" #include +#include #include #include #include "arrow/array/array_nested.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/exec.h" +#include "arrow/compute/function_internal.h" +#include "arrow/compute/registry.h" #include "arrow/datum.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" namespace arrow { +using internal::checked_cast; using internal::checked_pointer_cast; +namespace internal { +using compute::DictionaryEncodeOptions; +using compute::FilterOptions; +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "FilterOptions::NullSelectionBehavior"; } + static std::string value_name(FilterOptions::NullSelectionBehavior value) { + switch (value) { + case FilterOptions::DROP: + return "DROP"; + case FilterOptions::EMIT_NULL: + return "EMIT_NULL"; + } + return ""; + } +}; +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "DictionaryEncodeOptions::NullEncodingBehavior"; } + static std::string value_name(DictionaryEncodeOptions::NullEncodingBehavior value) { + switch (value) { + case DictionaryEncodeOptions::ENCODE: + return "ENCODE"; + case DictionaryEncodeOptions::MASK: + return "MASK"; + } + return ""; 
+ } +}; +} // namespace internal + namespace compute { +// ---------------------------------------------------------------------- +// Function options + +bool SortKey::Equals(const SortKey& other) const { + return name == other.name && order == other.order; +} +std::string SortKey::ToString() const { + std::stringstream ss; + ss << name << ' '; + switch (order) { + case SortOrder::Ascending: + ss << "ASC"; + break; + case SortOrder::Descending: + ss << "DESC"; + break; + } + return ss.str(); +} + +namespace internal { +namespace { +using ::arrow::internal::DataMember; +static auto kFilterOptionsType = GetFunctionOptionsType( + DataMember("null_selection_behavior", &FilterOptions::null_selection_behavior)); +static auto kTakeOptionsType = GetFunctionOptionsType( + DataMember("boundscheck", &TakeOptions::boundscheck)); +static auto kDictionaryEncodeOptionsType = + GetFunctionOptionsType(DataMember( + "null_encoding_behavior", &DictionaryEncodeOptions::null_encoding_behavior)); +static auto kArraySortOptionsType = GetFunctionOptionsType( + DataMember("order", &ArraySortOptions::order)); +static auto kSortOptionsType = + GetFunctionOptionsType(DataMember("sort_keys", &SortOptions::sort_keys)); +static auto kPartitionNthOptionsType = GetFunctionOptionsType( + DataMember("pivot", &PartitionNthOptions::pivot)); +} // namespace +} // namespace internal + +FilterOptions::FilterOptions(NullSelectionBehavior null_selection) + : FunctionOptions(internal::kFilterOptionsType), + null_selection_behavior(null_selection) {} +constexpr char FilterOptions::kTypeName[]; + +TakeOptions::TakeOptions(bool boundscheck) + : FunctionOptions(internal::kTakeOptionsType), boundscheck(boundscheck) {} +constexpr char TakeOptions::kTypeName[]; + +DictionaryEncodeOptions::DictionaryEncodeOptions(NullEncodingBehavior null_encoding) + : FunctionOptions(internal::kDictionaryEncodeOptionsType), + null_encoding_behavior(null_encoding) {} +constexpr char DictionaryEncodeOptions::kTypeName[]; + 
+ArraySortOptions::ArraySortOptions(SortOrder order) + : FunctionOptions(internal::kArraySortOptionsType), order(order) {} +constexpr char ArraySortOptions::kTypeName[]; + +SortOptions::SortOptions(std::vector sort_keys) + : FunctionOptions(internal::kSortOptionsType), sort_keys(std::move(sort_keys)) {} +constexpr char SortOptions::kTypeName[]; + +PartitionNthOptions::PartitionNthOptions(int64_t pivot) + : FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {} +constexpr char PartitionNthOptions::kTypeName[]; + +namespace internal { +void RegisterVectorOptions(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kTakeOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kDictionaryEncodeOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType)); +} +} // namespace internal + // ---------------------------------------------------------------------- // Direct exec interface to kernels diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index d67568e1567..2282b0098f9 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -32,7 +32,8 @@ class ExecContext; /// \addtogroup compute-concrete-options /// @{ -struct FilterOptions : public FunctionOptions { +class ARROW_EXPORT FilterOptions : public FunctionOptions { + public: /// Configure the action taken when a slot of the selection mask is null enum NullSelectionBehavior { /// the corresponding filtered value will be removed in the output @@ -41,30 +42,27 @@ struct FilterOptions : public FunctionOptions { EMIT_NULL, }; - explicit FilterOptions(NullSelectionBehavior null_selection = DROP) - : null_selection_behavior(null_selection) {} - + explicit FilterOptions(NullSelectionBehavior 
null_selection = DROP); + constexpr static char const kTypeName[] = "filter"; static FilterOptions Defaults() { return FilterOptions(); } NullSelectionBehavior null_selection_behavior = DROP; }; -struct ARROW_EXPORT TakeOptions : public FunctionOptions { - explicit TakeOptions(bool boundscheck = true) : boundscheck(boundscheck) {} - - bool boundscheck = true; +class ARROW_EXPORT TakeOptions : public FunctionOptions { + public: + explicit TakeOptions(bool boundscheck = true); + constexpr static char const kTypeName[] = "take"; static TakeOptions BoundsCheck() { return TakeOptions(true); } static TakeOptions NoBoundsCheck() { return TakeOptions(false); } static TakeOptions Defaults() { return BoundsCheck(); } -}; -enum class SortOrder { - Ascending, - Descending, + bool boundscheck = true; }; /// \brief Options for the dictionary encode function -struct DictionaryEncodeOptions : public FunctionOptions { +class ARROW_EXPORT DictionaryEncodeOptions : public FunctionOptions { + public: /// Configure how null values will be encoded enum NullEncodingBehavior { /// the null value will be added to the dictionary with a proper index @@ -73,44 +71,60 @@ struct DictionaryEncodeOptions : public FunctionOptions { MASK }; - explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK) - : null_encoding_behavior(null_encoding) {} - + explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK); + constexpr static char const kTypeName[] = "dictionary_encode"; static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); } NullEncodingBehavior null_encoding_behavior = MASK; }; +enum class SortOrder { + Ascending, + Descending, +}; + /// \brief One sort key for PartitionNthIndices (TODO) and SortIndices -struct ARROW_EXPORT SortKey { +class ARROW_EXPORT SortKey : public util::EqualityComparable { + public: explicit SortKey(std::string name, SortOrder order = SortOrder::Ascending) : name(name), order(order) {} + using 
util::EqualityComparable::Equals; + using util::EqualityComparable::operator==; + using util::EqualityComparable::operator!=; + bool Equals(const SortKey& other) const; + std::string ToString() const; + /// The name of the sort column. std::string name; /// How to order by this sort key. SortOrder order; }; -struct ARROW_EXPORT ArraySortOptions : public FunctionOptions { - explicit ArraySortOptions(SortOrder order = SortOrder::Ascending) : order(order) {} - +class ARROW_EXPORT ArraySortOptions : public FunctionOptions { + public: + explicit ArraySortOptions(SortOrder order = SortOrder::Ascending); + constexpr static char const kTypeName[] = "array_sort"; static ArraySortOptions Defaults() { return ArraySortOptions{}; } SortOrder order; }; -struct ARROW_EXPORT SortOptions : public FunctionOptions { - explicit SortOptions(std::vector sort_keys = {}) : sort_keys(sort_keys) {} - +class ARROW_EXPORT SortOptions : public FunctionOptions { + public: + explicit SortOptions(std::vector sort_keys = {}); + constexpr static char const kTypeName[] = "sort"; static SortOptions Defaults() { return SortOptions{}; } std::vector sort_keys; }; /// \brief Partitioning options for NthToIndices -struct ARROW_EXPORT PartitionNthOptions : public FunctionOptions { - explicit PartitionNthOptions(int64_t pivot) : pivot(pivot) {} +class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { + public: + explicit PartitionNthOptions(int64_t pivot); + PartitionNthOptions() : PartitionNthOptions(0) {} + constexpr static char const kTypeName[] = "partition_nth"; /// The index into the equivalent sorted array of the partition pivot element. 
int64_t pivot; diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 8a091f2355d..521f217213d 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -18,6 +18,7 @@ #include "arrow/compute/cast.h" #include +#include #include #include #include @@ -26,10 +27,12 @@ #include "arrow/compute/cast_internal.h" #include "arrow/compute/exec.h" +#include "arrow/compute/function_internal.h" #include "arrow/compute/kernel.h" #include "arrow/compute/kernels/codegen_internal.h" #include "arrow/compute/registry.h" #include "arrow/util/logging.h" +#include "arrow/util/reflection_internal.h" namespace arrow { @@ -38,6 +41,9 @@ using internal::ToTypeName; namespace compute { namespace internal { +// ---------------------------------------------------------------------- +// Function options + namespace { std::unordered_map> g_cast_table; @@ -116,14 +122,35 @@ class CastMetaFunction : public MetaFunction { } }; +static auto kCastOptionsType = GetFunctionOptionsType( + arrow::internal::DataMember("to_type", &CastOptions::to_type), + arrow::internal::DataMember("allow_int_overflow", &CastOptions::allow_int_overflow), + arrow::internal::DataMember("allow_time_truncate", &CastOptions::allow_time_truncate), + arrow::internal::DataMember("allow_time_overflow", &CastOptions::allow_time_overflow), + arrow::internal::DataMember("allow_decimal_truncate", + &CastOptions::allow_decimal_truncate), + arrow::internal::DataMember("allow_float_truncate", + &CastOptions::allow_float_truncate), + arrow::internal::DataMember("allow_invalid_utf8", &CastOptions::allow_invalid_utf8)); } // namespace void RegisterScalarCast(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::make_shared())); + DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType)); } - } // namespace internal +CastOptions::CastOptions(bool safe) + : FunctionOptions(internal::kCastOptionsType), + allow_int_overflow(!safe), + allow_time_truncate(!safe), + 
allow_time_overflow(!safe), + allow_decimal_truncate(!safe), + allow_float_truncate(!safe), + allow_invalid_utf8(!safe) {} + +constexpr char CastOptions::kTypeName[]; + CastFunction::CastFunction(std::string name, Type::type out_type_id) : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr), out_type_id_(out_type_id) {} diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 818f2ef9182..8abd2a71bca 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -41,15 +41,11 @@ class ExecContext; /// \addtogroup compute-concrete-options /// @{ -struct ARROW_EXPORT CastOptions : public FunctionOptions { - explicit CastOptions(bool safe = true) - : allow_int_overflow(!safe), - allow_time_truncate(!safe), - allow_time_overflow(!safe), - allow_decimal_truncate(!safe), - allow_float_truncate(!safe), - allow_invalid_utf8(!safe) {} +class ARROW_EXPORT CastOptions : public FunctionOptions { + public: + explicit CastOptions(bool safe = true); + constexpr static char const kTypeName[] = "cast"; static CastOptions Safe(std::shared_ptr to_type = NULLPTR) { CastOptions safe(true); safe.to_type = std::move(to_type); diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index 7659442d8bf..cd95db2fd8c 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -44,7 +44,7 @@ class CpuInfo; namespace compute { -struct FunctionOptions; +class FunctionOptions; class FunctionRegistry; // It seems like 64K might be a good default chunksize to use for execution diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 1c8c82de05e..aeabbf7bc5b 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -24,6 +24,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/exec/expression_internal.h" #include "arrow/compute/exec_internal.h" +#include "arrow/compute/function_internal.h" #include 
"arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" @@ -167,41 +168,14 @@ std::string Expression::ToString() const { out += arg.ToString() + ", "; } - if (call->options == nullptr) { + if (call->options) { + out += call->options->ToString(); + out.resize(out.size() + 1); + } else { out.resize(out.size() - 1); - out.back() = ')'; - return out; } - - if (auto options = GetSetLookupOptions(*call)) { - DCHECK_EQ(options->value_set.kind(), Datum::ARRAY); - out += "value_set=" + options->value_set.make_array()->ToString(); - if (options->skip_nulls) { - out += ", skip_nulls"; - } - return out + ")"; - } - - if (auto options = GetCastOptions(*call)) { - if (options->to_type == nullptr) { - return out + "to_type=)"; - } - out += "to_type=" + options->to_type->ToString(); - if (options->allow_int_overflow) out += ", allow_int_overflow"; - if (options->allow_time_truncate) out += ", allow_time_truncate"; - if (options->allow_time_overflow) out += ", allow_time_overflow"; - if (options->allow_decimal_truncate) out += ", allow_decimal_truncate"; - if (options->allow_float_truncate) out += ", allow_float_truncate"; - if (options->allow_invalid_utf8) out += ", allow_invalid_utf8"; - return out + ")"; - } - - if (auto options = GetStrptimeOptions(*call)) { - return out + "format=" + options->format + - ", unit=" + arrow::internal::ToString(options->unit) + ")"; - } - - return out + "{NON-REPRESENTABLE OPTIONS})"; + out.back() = ')'; + return out; } void PrintTo(const Expression& expr, std::ostream* os) { @@ -241,41 +215,9 @@ bool Expression::Equals(const Expression& other) const { } if (call->options == other_call->options) return true; - - if (auto options = GetSetLookupOptions(*call)) { - auto other_options = GetSetLookupOptions(*other_call); - return options->value_set == other_options->value_set && - options->skip_nulls == other_options->skip_nulls; + if (call->options && other_call->options) { + return call->options->Equals(other_call->options); } 
- - if (auto options = GetCastOptions(*call)) { - auto other_options = GetCastOptions(*other_call); - for (auto safety_opt : { - &compute::CastOptions::allow_int_overflow, - &compute::CastOptions::allow_time_truncate, - &compute::CastOptions::allow_time_overflow, - &compute::CastOptions::allow_decimal_truncate, - &compute::CastOptions::allow_float_truncate, - &compute::CastOptions::allow_invalid_utf8, - }) { - if (options->*safety_opt != other_options->*safety_opt) return false; - } - return options->to_type->Equals(other_options->to_type); - } - - if (auto options = GetProjectOptions(*call)) { - auto other_options = GetProjectOptions(*other_call); - return options->field_names == other_options->field_names; - } - - if (auto options = GetStrptimeOptions(*call)) { - auto other_options = GetStrptimeOptions(*other_call); - return options->format == other_options->format && - options->unit == other_options->unit; - } - - ARROW_LOG(WARNING) << "comparing unknown FunctionOptions for function " - << call->function_name; return false; } @@ -992,92 +934,6 @@ Result SimplifyWithGuarantee(Expression expr, return expr; } -namespace { - -Result> FunctionOptionsToStructScalar( - const Expression::Call& call) { - if (call.options == nullptr) { - return nullptr; - } - - if (auto options = GetSetLookupOptions(call)) { - if (!options->value_set.is_array()) { - return Status::NotImplemented("chunked value_set"); - } - return StructScalar::Make( - { - std::make_shared(options->value_set.make_array()), - MakeScalar(options->skip_nulls), - }, - {"value_set", "skip_nulls"}); - } - - if (auto options = GetCastOptions(call)) { - return StructScalar::Make( - { - MakeNullScalar(options->to_type), - MakeScalar(options->allow_int_overflow), - MakeScalar(options->allow_time_truncate), - MakeScalar(options->allow_time_overflow), - MakeScalar(options->allow_decimal_truncate), - MakeScalar(options->allow_float_truncate), - MakeScalar(options->allow_invalid_utf8), - }, - { - "to_type_holder", - 
"allow_int_overflow", - "allow_time_truncate", - "allow_time_overflow", - "allow_decimal_truncate", - "allow_float_truncate", - "allow_invalid_utf8", - }); - } - - return Status::NotImplemented("conversion of options for ", call.function_name); -} - -Status FunctionOptionsFromStructScalar(const StructScalar* repr, Expression::Call* call) { - if (repr == nullptr) { - call->options = nullptr; - return Status::OK(); - } - - if (IsSetLookup(call->function_name)) { - ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set")); - ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls")); - call->options = std::make_shared( - checked_cast(*value_set).value, - checked_cast(*skip_nulls).value); - return Status::OK(); - } - - if (call->function_name == "cast") { - auto options = std::make_shared(); - ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder")); - options->to_type = to_type_holder->type; - - int i = 1; - for (bool* opt : { - &options->allow_int_overflow, - &options->allow_time_truncate, - &options->allow_time_overflow, - &options->allow_decimal_truncate, - &options->allow_float_truncate, - &options->allow_invalid_utf8, - }) { - *opt = checked_cast(*repr->value[i++]).value; - } - - call->options = std::move(options); - return Status::OK(); - } - - return Status::NotImplemented("conversion of options for ", call->function_name); -} - -} // namespace - // Serialization is accomplished by converting expressions to KeyValueMetadata and storing // this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its // columns. Finally, the RecordBatch is written to an IPC file. 
@@ -1119,7 +975,8 @@ Result> Serialize(const Expression& expr) { } if (call->options) { - ARROW_ASSIGN_OR_RAISE(auto options_scalar, FunctionOptionsToStructScalar(*call)); + ARROW_ASSIGN_OR_RAISE(auto options_scalar, + internal::FunctionOptionsToStructScalar(*call->options)); ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*options_scalar)); metadata_->Append("options", std::move(value)); } @@ -1204,10 +1061,13 @@ Result Deserialize(std::shared_ptr buffer) { while (metadata().key(index_) != "end") { if (metadata().key(index_) == "options") { ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_))); - auto expr = call(value, std::move(arguments)); - RETURN_NOT_OK(FunctionOptionsFromStructScalar( - checked_cast(options_scalar.get()), - const_cast(expr.call()))); + std::shared_ptr options; + if (options_scalar) { + ARROW_ASSIGN_OR_RAISE( + options, internal::FunctionOptionsFromStructScalar( + checked_cast(*options_scalar))); + } + auto expr = call(value, std::move(arguments), std::move(options)); index_ += 2; return expr; } diff --git a/cpp/src/arrow/compute/exec/expression_internal.h b/cpp/src/arrow/compute/exec/expression_internal.h index 7b0cc758f57..b9165a5f0c2 100644 --- a/cpp/src/arrow/compute/exec/expression_internal.h +++ b/cpp/src/arrow/compute/exec/expression_internal.h @@ -216,22 +216,11 @@ inline bool IsSetLookup(const std::string& function) { return function == "is_in" || function == "index_in"; } -inline const compute::SetLookupOptions* GetSetLookupOptions( - const Expression::Call& call) { - if (!IsSetLookup(call.function_name)) return nullptr; - return checked_cast(call.options.get()); -} - inline const compute::ProjectOptions* GetProjectOptions(const Expression::Call& call) { if (call.function_name != "project") return nullptr; return checked_cast(call.options.get()); } -inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call& call) { - if (call.function_name != "strptime") return nullptr; - return 
checked_cast(call.options.get()); -} - /// A helper for unboxing an Expression composed of associative function calls. /// Such expressions can frequently be rearranged to a semantically equivalent /// expression for more optimal execution or more straightforward manipulation. diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 66212bf99d6..908e8962e43 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -27,6 +27,7 @@ #include #include "arrow/compute/exec/expression_internal.h" +#include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" #include "arrow/testing/gtest_util.h" @@ -184,17 +185,43 @@ TEST(Expression, ToString) { auto in_12 = call("index_in", {field_ref("beta")}, compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")}); - EXPECT_EQ(in_12.ToString(), "index_in(beta, value_set=[\n 1,\n 2\n])"); + EXPECT_EQ(in_12.ToString(), + "index_in(beta, {value_set=int32:[\n 1,\n 2\n], skip_nulls=false})"); EXPECT_EQ(and_(field_ref("a"), field_ref("b")).ToString(), "(a and b)"); EXPECT_EQ(or_(field_ref("a"), field_ref("b")).ToString(), "(a or b)"); EXPECT_EQ(not_(field_ref("a")).ToString(), "invert(a)"); - EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), "cast(a, to_type=int32)"); - EXPECT_EQ(cast(field_ref("a"), nullptr).ToString(), - "cast(a, to_type=)"); - - struct WidgetifyOptions : compute::FunctionOptions { + EXPECT_EQ( + cast(field_ref("a"), int32()).ToString(), + "cast(a, {to_type=int32, allow_int_overflow=false, allow_time_truncate=false, " + "allow_time_overflow=false, allow_decimal_truncate=false, " + "allow_float_truncate=false, allow_invalid_utf8=false})"); + EXPECT_EQ( + cast(field_ref("a"), nullptr).ToString(), + "cast(a, {to_type=, allow_int_overflow=false, allow_time_truncate=false, " + "allow_time_overflow=false, allow_decimal_truncate=false, " + "allow_float_truncate=false, 
allow_invalid_utf8=false})"); + + class WidgetifyOptionsType : public FunctionOptionsType { + public: + static const FunctionOptionsType* GetInstance() { + static std::unique_ptr instance(new WidgetifyOptionsType()); + return instance.get(); + } + const char* type_name() const override { return "widgetify"; } + std::string Stringify(const FunctionOptions& options) const override { + return type_name(); + } + bool Compare(const FunctionOptions& options, + const FunctionOptions& other) const override { + return true; + } + }; + class WidgetifyOptions : public compute::FunctionOptions { + public: + explicit WidgetifyOptions(bool really = true) + : FunctionOptions(WidgetifyOptionsType::GetInstance()), really(really) {} bool really; }; @@ -202,7 +229,7 @@ TEST(Expression, ToString) { EXPECT_EQ(call("widgetify", {}).ToString(), "widgetif)"); EXPECT_EQ( call("widgetify", {literal(1)}, std::make_shared()).ToString(), - "widgetify(1, {NON-REPRESENTABLE OPTIONS})"); + "widgetify(1, widgetify)"); EXPECT_EQ(equal(field_ref("a"), literal(1)).ToString(), "(a == 1)"); EXPECT_EQ(less(field_ref("a"), literal(2)).ToString(), "(a < 2)"); diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index f2cd7d2a740..ae2c9446aa9 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -392,7 +392,7 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, RecordBatchCollectNode* MakeRecordBatchCollectNode( ExecPlan* plan, std::string label, const std::shared_ptr& schema) { - return internal::checked_cast( + return arrow::internal::checked_cast( plan->EmplaceNode(plan, std::move(label), schema)); } diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index c56e6471c97..8ce7e52d252 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -31,6 +31,7 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec_internal.h" 
#include "arrow/compute/function.h" +#include "arrow/compute/function_internal.h" #include "arrow/compute/kernel.h" #include "arrow/compute/registry.h" #include "arrow/memory_pool.h" @@ -50,6 +51,10 @@ using internal::checked_cast; namespace compute { namespace detail { +using ::arrow::internal::BitmapEquals; +using ::arrow::internal::CopyBitmap; +using ::arrow::internal::CountSetBits; + TEST(ExecContext, BasicWorkings) { { ExecContext ctx; @@ -58,7 +63,7 @@ TEST(ExecContext, BasicWorkings) { ASSERT_EQ(std::numeric_limits::max(), ctx.exec_chunksize()); ASSERT_TRUE(ctx.use_threads()); - ASSERT_EQ(internal::CpuInfo::GetInstance(), ctx.cpu_info()); + ASSERT_EQ(arrow::internal::CpuInfo::GetInstance(), ctx.cpu_info()); } // Now, let's customize all the things @@ -277,9 +282,9 @@ TEST_F(TestPropagateNulls, SingleValueWithNulls) { ASSERT_EQ(arr->Slice(offset)->null_count(), output.GetNullCount()); - ASSERT_TRUE(internal::BitmapEquals(output.buffers[0]->data(), output.offset, - sliced->null_bitmap_data(), sliced->offset(), - output.length)); + ASSERT_TRUE(BitmapEquals(output.buffers[0]->data(), output.offset, + sliced->null_bitmap_data(), sliced->offset(), + output.length)); AssertValidityZeroExtraBits(output); }; @@ -372,8 +377,8 @@ TEST_F(TestPropagateNulls, IntersectsNulls) { const auto& out_buffer = *output.buffers[0]; - ASSERT_TRUE(internal::BitmapEquals(out_buffer.data(), output_offset, ex_bitmap, - /*ex_offset=*/0, length)); + ASSERT_TRUE(BitmapEquals(out_buffer.data(), output_offset, ex_bitmap, + /*ex_offset=*/0, length)); // Now check that the rest of the bits in out_buffer are still 0 AssertValidityZeroExtraBits(output); @@ -556,15 +561,14 @@ Status ExecComputedBitmap(KernelContext* ctx, const ExecBatch& batch, Datum* out const ArrayData& arg0 = *batch[0].array(); ArrayData* out_arr = out->mutable_array(); - if (internal::CountSetBits(arg0.buffers[0]->data(), arg0.offset, batch.length) > 0) { + if (CountSetBits(arg0.buffers[0]->data(), arg0.offset, batch.length) 
> 0) { // Check that the bitmap has not been already copied over - DCHECK(!internal::BitmapEquals(arg0.buffers[0]->data(), arg0.offset, - out_arr->buffers[0]->data(), out_arr->offset, - batch.length)); + DCHECK(!BitmapEquals(arg0.buffers[0]->data(), arg0.offset, + out_arr->buffers[0]->data(), out_arr->offset, batch.length)); } - internal::CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length, - out_arr->buffers[0]->mutable_data(), out_arr->offset); + CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length, + out_arr->buffers[0]->mutable_data(), out_arr->offset); return ExecCopy(ctx, batch, out); } @@ -587,16 +591,33 @@ Status ExecNoPreallocatedAnything(KernelContext* ctx, const ExecBatch& batch, Status s = (ctx->AllocateBitmap(out_arr->length).Value(&out_arr->buffers[0])); DCHECK_OK(s); const ArrayData& arg0 = *batch[0].array(); - internal::CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length, - out_arr->buffers[0]->mutable_data(), /*offset=*/0); + CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length, + out_arr->buffers[0]->mutable_data(), /*offset=*/0); // Reuse the kernel that allocates the data return ExecNoPreallocatedData(ctx, batch, out); } -struct ExampleOptions : public FunctionOptions { +class ExampleOptionsType : public FunctionOptionsType { + public: + static const FunctionOptionsType* GetInstance() { + static std::unique_ptr instance(new ExampleOptionsType()); + return instance.get(); + } + const char* type_name() const override { return "example"; } + std::string Stringify(const FunctionOptions& options) const override { + return type_name(); + } + bool Compare(const FunctionOptions& options, + const FunctionOptions& other) const override { + return true; + } +}; +class ExampleOptions : public FunctionOptions { + public: + explicit ExampleOptions(std::shared_ptr value) + : FunctionOptions(ExampleOptionsType::GetInstance()), value(std::move(value)) {} std::shared_ptr value; - explicit ExampleOptions(std::shared_ptr value) 
: value(std::move(value)) {} }; struct ExampleState : public KernelState { diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 0f94baaedfc..05d14d03b16 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -21,10 +21,13 @@ #include #include +#include "arrow/compute/api_scalar.h" #include "arrow/compute/cast.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec_internal.h" +#include "arrow/compute/function_internal.h" #include "arrow/compute/kernels/common.h" +#include "arrow/compute/registry.h" #include "arrow/datum.h" #include "arrow/util/cpu_info.h" @@ -33,6 +36,38 @@ namespace arrow { using internal::checked_cast; namespace compute { +Result> FunctionOptionsType::Serialize( + const FunctionOptions&) const { + return Status::NotImplemented("Serialize for ", type_name()); +} + +Result> FunctionOptionsType::Deserialize( + const Buffer& buffer) const { + return Status::NotImplemented("Deserialize for ", type_name()); +} + +std::string FunctionOptions::ToString() const { return options_type()->Stringify(*this); } + +bool FunctionOptions::Equals(const FunctionOptions& other) const { + if (this == &other) return true; + if (options_type() != other.options_type()) return false; + return options_type()->Compare(*this, other); +} + +Result> FunctionOptions::Serialize() const { + return options_type()->Serialize(*this); +} + +Result> FunctionOptions::Deserialize( + const std::string& type_name, const Buffer& buffer) { + ARROW_ASSIGN_OR_RAISE(auto options, + GetFunctionRegistry()->GetFunctionOptionsType(type_name)); + return options->Deserialize(buffer); +} + +void PrintTo(const FunctionOptions& options, std::ostream* os) { + *os << options.ToString(); +} static const FunctionDoc kEmptyFunctionDoc{}; diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index 9a3e1c1852f..bd854bbb28e 100644 --- a/cpp/src/arrow/compute/function.h +++ 
b/cpp/src/arrow/compute/function.h @@ -29,6 +29,7 @@ #include "arrow/datum.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/util/compare.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -39,12 +40,50 @@ namespace compute { /// /// @{ +/// \brief Extension point for defining options outside libarrow (but +/// still within this project). +class ARROW_EXPORT FunctionOptionsType { + public: + virtual ~FunctionOptionsType() = default; + + virtual const char* type_name() const = 0; + virtual std::string Stringify(const FunctionOptions&) const = 0; + virtual bool Compare(const FunctionOptions&, const FunctionOptions&) const = 0; + virtual Result> Serialize(const FunctionOptions&) const; + virtual Result> Deserialize( + const Buffer& buffer) const; +}; + /// \brief Base class for specifying options configuring a function's behavior, /// such as error handling. -struct ARROW_EXPORT FunctionOptions { +class ARROW_EXPORT FunctionOptions : public util::EqualityComparable { + public: virtual ~FunctionOptions() = default; + + const FunctionOptionsType* options_type() const { return options_type_; } + const char* type_name() const { return options_type()->type_name(); } + + bool Equals(const FunctionOptions& other) const; + using util::EqualityComparable::Equals; + using util::EqualityComparable::operator==; + using util::EqualityComparable::operator!=; + std::string ToString() const; + /// \brief Serialize an options struct to a buffer. + Result> Serialize() const; + /// \brief Deserialize an options struct from a buffer. + /// Note: this will only look for `type_name` in the default FunctionRegistry; + /// to use a custom FunctionRegistry, look up the FunctionOptionsType, then + /// call FunctionOptionsType::Deserialize(). 
+ static Result> Deserialize( + const std::string& type_name, const Buffer& buffer); + + protected: + explicit FunctionOptions(const FunctionOptionsType* type) : options_type_(type) {} + const FunctionOptionsType* options_type_; }; +ARROW_EXPORT void PrintTo(const FunctionOptions&, std::ostream*); + /// \brief Contains the number of required arguments for the function. /// /// Naming conventions taken from https://en.wikipedia.org/wiki/Arity. diff --git a/cpp/src/arrow/compute/function_internal.cc b/cpp/src/arrow/compute/function_internal.cc new file mode 100644 index 00000000000..5234a421a7e --- /dev/null +++ b/cpp/src/arrow/compute/function_internal.cc @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/function_internal.h" + +#include "arrow/array/util.h" +#include "arrow/compute/function.h" +#include "arrow/compute/registry.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/scalar.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { +namespace compute { +namespace internal { +using ::arrow::internal::checked_cast; + +constexpr char kTypeNameField[] = "_type_name"; + +Result> FunctionOptionsToStructScalar( + const FunctionOptions& options) { + std::vector field_names; + std::vector> values; + const auto* options_type = + checked_cast(options.options_type()); + RETURN_NOT_OK(options_type->ToStructScalar(options, &field_names, &values)); + field_names.push_back(kTypeNameField); + const char* options_name = options.type_name(); + values.emplace_back( + new BinaryScalar(Buffer::Wrap(options_name, std::strlen(options_name)))); + return StructScalar::Make(std::move(values), std::move(field_names)); +} + +Result> FunctionOptionsFromStructScalar( + const StructScalar& scalar) { + ARROW_ASSIGN_OR_RAISE(auto type_name_holder, scalar.field(kTypeNameField)); + const std::string type_name = + checked_cast(*type_name_holder).value->ToString(); + ARROW_ASSIGN_OR_RAISE(auto raw_options_type, + GetFunctionRegistry()->GetFunctionOptionsType(type_name)); + const auto* options_type = checked_cast(raw_options_type); + return options_type->FromStructScalar(scalar); +} + +Result> GenericOptionsType::Serialize( + const FunctionOptions& options) const { + ARROW_ASSIGN_OR_RAISE(auto scalar, FunctionOptionsToStructScalar(options)); + ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*scalar, 1)); + auto batch = + RecordBatch::Make(schema({field("", array->type())}), /*num_rows=*/1, {array}); + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); + ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); + 
RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + RETURN_NOT_OK(writer->Close()); + return stream->Finish(); +} + +Result> GenericOptionsType::Deserialize( + const Buffer& buffer) const { + return DeserializeFunctionOptions(buffer); +} + +Result> DeserializeFunctionOptions( + const Buffer& buffer) { + io::BufferReader stream(buffer); + ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); + ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); + if (batch->num_rows() != 1) { + return Status::Invalid( + "serialized FunctionOptions's batch repr was not a single row - had ", + batch->num_rows()); + } + if (batch->num_columns() != 1) { + return Status::Invalid( + "serialized FunctionOptions's batch repr was not a single column - had ", + batch->num_columns()); + } + auto column = batch->column(0); + if (column->type()->id() != Type::STRUCT) { + return Status::Invalid( + "serialized FunctionOptions's batch repr was not a struct column - was ", + column->type()->ToString()); + } + ARROW_ASSIGN_OR_RAISE(auto raw_scalar, + checked_cast(*column).GetScalar(0)); + auto scalar = checked_cast(*raw_scalar); + return FunctionOptionsFromStructScalar(scalar); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/function_internal.h b/cpp/src/arrow/compute/function_internal.h new file mode 100644 index 00000000000..fdd7f09ba1f --- /dev/null +++ b/cpp/src/arrow/compute/function_internal.h @@ -0,0 +1,626 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/array/builder_base.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_nested.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/function.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/key_value_metadata.h" +#include "arrow/util/reflection_internal.h" +#include "arrow/util/string.h" +#include "arrow/util/visibility.h" + +namespace arrow { +struct Scalar; +struct StructScalar; +using ::arrow::internal::checked_cast; + +namespace internal { +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "SortOrder"; } + static std::string value_name(compute::SortOrder value) { + switch (value) { + case compute::SortOrder::Ascending: + return "Ascending"; + case compute::SortOrder::Descending: + return "Descending"; + } + return ""; + } +}; +} // namespace internal + +namespace compute { +namespace internal { + +using arrow::internal::EnumTraits; +using arrow::internal::has_enum_traits; + +template ::type> +Result ValidateEnumValue(CType raw) { + for (auto valid : EnumTraits::values()) { + if (raw == static_cast(valid)) { + return static_cast(raw); + } + } + return Status::Invalid("Invalid value for ", EnumTraits::name(), ": ", raw); +} + +class GenericOptionsType : public FunctionOptionsType { + public: + Result> Serialize(const FunctionOptions&) const override; + Result> Deserialize( + const Buffer& buffer) const 
override; + virtual Status ToStructScalar(const FunctionOptions& options, + std::vector* field_names, + std::vector>* values) const = 0; + virtual Result> FromStructScalar( + const StructScalar& scalar) const = 0; +}; + +ARROW_EXPORT +Result> FunctionOptionsToStructScalar( + const FunctionOptions&); +ARROW_EXPORT +Result> FunctionOptionsFromStructScalar( + const StructScalar&); +ARROW_EXPORT +Result> DeserializeFunctionOptions(const Buffer& buffer); + +template +static inline enable_if_t::value, std::string> GenericToString( + const T& value) { + std::stringstream ss; + ss << value; + return ss.str(); +} + +static inline std::string GenericToString(bool value) { return value ? "true" : "false"; } + +static inline std::string GenericToString(const std::string& value) { + std::stringstream ss; + ss << '"' << value << '"'; + return ss.str(); +} + +template +static inline enable_if_t::value, std::string> GenericToString( + const T value) { + return EnumTraits::value_name(value); +} + +template +static inline std::string GenericToString(const std::shared_ptr& value) { + std::stringstream ss; + return value ? 
value->ToString() : ""; +} + +static inline std::string GenericToString(const std::shared_ptr& value) { + std::stringstream ss; + ss << value->type->ToString() << ":" << value->ToString(); + return ss.str(); +} + +static inline std::string GenericToString( + const std::shared_ptr& value) { + std::stringstream ss; + ss << "KeyValueMetadata{"; + if (value) { + bool first = true; + for (const auto& pair : value->sorted_pairs()) { + if (!first) ss << ", "; + first = false; + ss << pair.first << ':' << pair.second; + } + } + ss << '}'; + return ss.str(); +} + +static inline std::string GenericToString(const Datum& value) { + switch (value.kind()) { + case Datum::NONE: + return ""; + case Datum::SCALAR: + return GenericToString(value.scalar()); + case Datum::ARRAY: { + std::stringstream ss; + ss << value.type()->ToString() << ':' << value.make_array()->ToString(); + return ss.str(); + } + case Datum::CHUNKED_ARRAY: + case Datum::RECORD_BATCH: + case Datum::TABLE: + case Datum::COLLECTION: + return value.ToString(); + } + return value.ToString(); +} + +template +static inline std::string GenericToString(const std::vector& value) { + std::stringstream ss; + ss << "["; + bool first = true; + // Don't use range-for with auto& to avoid Clang -Wrange-loop-analysis + for (auto it = value.begin(); it != value.end(); it++) { + if (!first) ss << ", "; + first = false; + ss << GenericToString(*it); + } + ss << ']'; + return ss.str(); +} + +static inline std::string GenericToString(SortOrder value) { + switch (value) { + case SortOrder::Ascending: + return "Ascending"; + case SortOrder::Descending: + return "Descending"; + } + return ""; +} + +static inline std::string GenericToString(const std::vector& value) { + std::stringstream ss; + ss << '['; + bool first = true; + for (const auto& key : value) { + if (!first) { + ss << ", "; + } + first = false; + ss << key.ToString(); + } + ss << ']'; + return ss.str(); +} + +template +static inline bool GenericEquals(const T& left, const T& 
right) { + return left == right; +} + +template +static inline bool GenericEquals(const std::shared_ptr& left, + const std::shared_ptr& right) { + if (left && right) { + return left->Equals(*right); + } + return left == right; +} + +static inline bool IsEmpty(const std::shared_ptr& meta) { + return !meta || meta->size() == 0; +} + +static inline bool GenericEquals(const std::shared_ptr& left, + const std::shared_ptr& right) { + // Special case since null metadata is considered equivalent to empty + if (IsEmpty(left) || IsEmpty(right)) { + return IsEmpty(left) && IsEmpty(right); + } + return left->Equals(*right); +} + +template +static inline bool GenericEquals(const std::vector& left, + const std::vector& right) { + if (left.size() != right.size()) return false; + for (size_t i = 0; i < left.size(); i++) { + if (!GenericEquals(left[i], right[i])) return false; + } + return true; +} + +template +static inline decltype(TypeTraits::ArrowType>::type_singleton()) +GenericTypeSingleton() { + return TypeTraits::ArrowType>::type_singleton(); +} + +template +static inline enable_if_same, + std::shared_ptr> +GenericTypeSingleton() { + return map(binary(), binary()); +} + +template +static inline enable_if_t::value, std::shared_ptr> +GenericTypeSingleton() { + return TypeTraits::Type>::type_singleton(); +} + +template +static inline enable_if_same> +GenericTypeSingleton() { + std::vector> fields; + fields.emplace_back(new Field("name", GenericTypeSingleton())); + fields.emplace_back(new Field("order", GenericTypeSingleton())); + return std::make_shared(std::move(fields)); +} + +// N.B. 
ordering of overloads is relatively fragile +template +static inline Result()))> GenericToScalar( + const T& value) { + return MakeScalar(value); +} + +// For Clang/libc++: when iterating through vector, we can't +// pass it by reference so the overload above doesn't apply +static inline Result> GenericToScalar(bool value) { + return MakeScalar(value); +} + +template ::value>> +static inline Result> GenericToScalar(const T value) { + using CType = typename EnumTraits::CType; + return GenericToScalar(static_cast(value)); +} + +static inline Result> GenericToScalar(const SortKey& value) { + ARROW_ASSIGN_OR_RAISE(auto name, GenericToScalar(value.name)); + ARROW_ASSIGN_OR_RAISE(auto order, GenericToScalar(value.order)); + return StructScalar::Make({name, order}, {"name", "order"}); +} + +static inline Result> GenericToScalar( + const std::shared_ptr& value) { + auto ty = GenericTypeSingleton>(); + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(default_memory_pool(), ty, &builder)); + auto* map_builder = checked_cast(builder.get()); + auto* key_builder = checked_cast(map_builder->key_builder()); + auto* item_builder = checked_cast(map_builder->item_builder()); + RETURN_NOT_OK(map_builder->Append()); + if (value) { + RETURN_NOT_OK(key_builder->AppendValues(value->keys())); + RETURN_NOT_OK(item_builder->AppendValues(value->values())); + } + std::shared_ptr arr; + RETURN_NOT_OK(map_builder->Finish(&arr)); + return arr->GetScalar(0); +} + +template +static inline Result> GenericToScalar( + const std::vector& value) { + std::shared_ptr type = GenericTypeSingleton(); + std::vector> scalars; + scalars.reserve(value.size()); + // Don't use range-for with auto& to avoid Clang -Wrange-loop-analysis + for (auto it = value.begin(); it != value.end(); it++) { + ARROW_ASSIGN_OR_RAISE(auto scalar, GenericToScalar(*it)); + scalars.push_back(std::move(scalar)); + } + std::unique_ptr builder; + RETURN_NOT_OK( + MakeBuilder(default_memory_pool(), type ? 
type : scalars[0]->type, &builder)); + RETURN_NOT_OK(builder->AppendScalars(scalars)); + std::shared_ptr out; + RETURN_NOT_OK(builder->Finish(&out)); + return std::make_shared(std::move(out)); +} + +static inline Result> GenericToScalar( + const std::shared_ptr& value) { + if (!value) { + return Status::Invalid("shared_ptr is nullptr"); + } + return MakeNullScalar(value); +} + +static inline Result> GenericToScalar( + const std::shared_ptr& value) { + return value; +} + +static inline Result> GenericToScalar( + const std::shared_ptr& value) { + return std::make_shared(value); +} + +static inline Result> GenericToScalar(const Datum& value) { + // TODO(ARROW-9434): store in a union instead. + switch (value.kind()) { + case Datum::ARRAY: + return GenericToScalar(value.make_array()); + break; + default: + return Status::NotImplemented("Cannot serialize Datum kind ", value.kind()); + } +} + +template +static inline enable_if_primitive_ctype::ArrowType, Result> +GenericFromScalar(const std::shared_ptr& value) { + using ArrowType = typename CTypeTraits::ArrowType; + using ScalarType = typename TypeTraits::ScalarType; + if (value->type->id() != ArrowType::type_id) { + return Status::Invalid("Expected type ", ArrowType::type_id, " but got ", + value->type->ToString()); + } + const auto& holder = checked_cast(*value); + if (!holder.is_valid) return Status::Invalid("Got null scalar"); + return holder.value; +} + +template +static inline enable_if_primitive_ctype::Type, Result> +GenericFromScalar(const std::shared_ptr& value) { + ARROW_ASSIGN_OR_RAISE(auto raw_val, + GenericFromScalar::CType>(value)); + return ValidateEnumValue(raw_val); +} + +template +using enable_if_same_result = enable_if_same>; + +template +static inline enable_if_same_result GenericFromScalar( + const std::shared_ptr& value) { + if (!is_base_binary_like(value->type->id())) { + return Status::Invalid("Expected binary-like type but got ", value->type->ToString()); + } + const auto& holder = 
checked_cast(*value); + if (!holder.is_valid) return Status::Invalid("Got null scalar"); + return holder.value->ToString(); +} + +template +static inline enable_if_same_result GenericFromScalar( + const std::shared_ptr& value) { + if (value->type->id() != Type::STRUCT) { + return Status::Invalid("Expected type STRUCT but got ", value->type->id()); + } + if (!value->is_valid) return Status::Invalid("Got null scalar"); + const auto& holder = checked_cast(*value); + ARROW_ASSIGN_OR_RAISE(auto name_holder, holder.field("name")); + ARROW_ASSIGN_OR_RAISE(auto order_holder, holder.field("order")); + ARROW_ASSIGN_OR_RAISE(auto name, GenericFromScalar(name_holder)); + ARROW_ASSIGN_OR_RAISE(auto order, GenericFromScalar(order_holder)); + return SortKey{std::move(name), order}; +} + +template +static inline enable_if_same_result> GenericFromScalar( + const std::shared_ptr& value) { + return value->type; +} + +template +static inline enable_if_same_result> GenericFromScalar( + const std::shared_ptr& value) { + return value; +} + +template +static inline enable_if_same_result> +GenericFromScalar(const std::shared_ptr& value) { + auto ty = GenericTypeSingleton>(); + if (!value->type->Equals(ty)) { + return Status::Invalid("Expected ", ty->ToString(), " but got ", + value->type->ToString()); + } + const auto& holder = checked_cast(*value); + std::vector keys; + std::vector values; + const auto& list = checked_cast(*holder.value); + const auto& key_arr = checked_cast(*list.field(0)); + const auto& value_arr = checked_cast(*list.field(1)); + for (int64_t i = 0; i < list.length(); i++) { + keys.push_back(key_arr.GetString(i)); + values.push_back(value_arr.GetString(i)); + } + return key_value_metadata(std::move(keys), std::move(values)); +} + +template +static inline enable_if_same_result GenericFromScalar( + const std::shared_ptr& value) { + if (value->type->id() == Type::LIST) { + const auto& holder = checked_cast(*value); + return holder.value; + } + // TODO(ARROW-9434): handle 
other possible datum kinds by looking for a union + return Status::Invalid("Cannot deserialize Datum from ", value->ToString()); +} + +template +static enable_if_same::ArrowType, ListType, Result> +GenericFromScalar(const std::shared_ptr& value) { + using ValueType = typename T::value_type; + if (value->type->id() != Type::LIST) { + return Status::Invalid("Expected type LIST but got ", value->type->ToString()); + } + const auto& holder = checked_cast(*value); + if (!holder.is_valid) return Status::Invalid("Got null scalar"); + std::vector result; + for (int i = 0; i < holder.value->length(); i++) { + ARROW_ASSIGN_OR_RAISE(auto scalar, holder.value->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto v, GenericFromScalar(scalar)); + result.push_back(std::move(v)); + } + return result; +} + +template +struct StringifyImpl { + template + StringifyImpl(const Options& obj, const Tuple& props) + : obj_(obj), members_(props.size()) { + props.ForEach(*this); + } + + template + void operator()(const Property& prop, size_t i) { + std::stringstream ss; + ss << prop.name() << '=' << GenericToString(prop.get(obj_)); + members_[i] = ss.str(); + } + + std::string Finish() { + return "{" + arrow::internal::JoinStrings(members_, ", ") + "}"; + } + + const Options& obj_; + std::vector members_; +}; + +template +struct CompareImpl { + template + CompareImpl(const Options& l, const Options& r, const Tuple& props) + : left_(l), right_(r) { + props.ForEach(*this); + } + + template + void operator()(const Property& prop, size_t) { + equal_ &= GenericEquals(prop.get(left_), prop.get(right_)); + } + + const Options& left_; + const Options& right_; + bool equal_ = true; +}; + +template +struct ToStructScalarImpl { + template + ToStructScalarImpl(const Options& obj, const Tuple& props, + std::vector* field_names, + std::vector>* values) + : obj_(obj), field_names_(field_names), values_(values) { + props.ForEach(*this); + } + + template + void operator()(const Property& prop, size_t) { + if 
(!status_.ok()) return; + auto result = GenericToScalar(prop.get(obj_)); + if (!result.ok()) { + status_ = result.status().WithMessage("Could not serialize field ", prop.name(), + " of options type ", Options::kTypeName, ": ", + result.status().message()); + return; + } + field_names_->emplace_back(prop.name()); + values_->push_back(result.MoveValueUnsafe()); + } + + const Options& obj_; + Status status_; + std::vector* field_names_; + std::vector>* values_; +}; + +template +struct FromStructScalarImpl { + template + FromStructScalarImpl(Options* obj, const StructScalar& scalar, const Tuple& props) + : obj_(obj), scalar_(scalar) { + props.ForEach(*this); + } + + template + void operator()(const Property& prop, size_t) { + if (!status_.ok()) return; + auto maybe_holder = scalar_.field(std::string(prop.name())); + if (!maybe_holder.ok()) { + status_ = maybe_holder.status().WithMessage( + "Cannot deserialize field ", prop.name(), " of options type ", + Options::kTypeName, ": ", maybe_holder.status().message()); + return; + } + auto holder = maybe_holder.MoveValueUnsafe(); + auto result = GenericFromScalar(holder); + if (!result.ok()) { + status_ = result.status().WithMessage("Cannot deserialize field ", prop.name(), + " of options type ", Options::kTypeName, ": ", + result.status().message()); + return; + } + prop.set(obj_, result.MoveValueUnsafe()); + } + + Options* obj_; + Status status_; + const StructScalar& scalar_; +}; + +template +const FunctionOptionsType* GetFunctionOptionsType(const Properties&... 
properties) { + static const class OptionsType : public GenericOptionsType { + public: + explicit OptionsType(const arrow::internal::PropertyTuple properties) + : properties_(properties) {} + + const char* type_name() const override { return Options::kTypeName; } + + std::string Stringify(const FunctionOptions& options) const override { + const auto& self = checked_cast(options); + return StringifyImpl(self, properties_).Finish(); + } + bool Compare(const FunctionOptions& options, + const FunctionOptions& other) const override { + const auto& lhs = checked_cast(options); + const auto& rhs = checked_cast(other); + return CompareImpl(lhs, rhs, properties_).equal_; + } + Status ToStructScalar(const FunctionOptions& options, + std::vector* field_names, + std::vector>* values) const override { + const auto& self = checked_cast(options); + RETURN_NOT_OK( + ToStructScalarImpl(self, properties_, field_names, values).status_); + return Status::OK(); + } + Result> FromStructScalar( + const StructScalar& scalar) const override { + auto options = std::unique_ptr(new Options()); + RETURN_NOT_OK( + FromStructScalarImpl(options.get(), scalar, properties_).status_); + return std::move(options); + } + + private: + const arrow::internal::PropertyTuple properties_; + } instance(arrow::internal::MakeProperties(properties...)); + return &instance; +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index 581555e931f..4c42ce39600 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -21,16 +21,113 @@ #include +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/cast.h" #include "arrow/compute/function.h" #include "arrow/compute/kernel.h" #include "arrow/datum.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include 
"arrow/type.h" +#include "arrow/util/key_value_metadata.h" namespace arrow { namespace compute { +TEST(FunctionOptions, Equality) { + std::vector> options; + options.emplace_back(new ScalarAggregateOptions()); + options.emplace_back(new ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); + options.emplace_back(new ModeOptions()); + options.emplace_back(new ModeOptions(/*n=*/2)); + options.emplace_back(new VarianceOptions()); + options.emplace_back(new VarianceOptions(/*ddof=*/2)); + options.emplace_back(new QuantileOptions()); + options.emplace_back( + new QuantileOptions(/*q=*/0.75, QuantileOptions::Interpolation::MIDPOINT)); + options.emplace_back(new TDigestOptions()); + options.emplace_back( + new TDigestOptions(/*q=*/0.75, /*delta=*/50, /*buffer_size=*/1024)); + options.emplace_back(new IndexOptions(ScalarFromJSON(int64(), "16"))); + options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "true"))); + options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "null"))); + options.emplace_back(new ElementWiseAggregateOptions()); + options.emplace_back(new ElementWiseAggregateOptions(/*skip_nulls=*/false)); + options.emplace_back(new JoinOptions()); + options.emplace_back(new JoinOptions(JoinOptions::REPLACE, "replacement")); + options.emplace_back(new MatchSubstringOptions("pattern")); + options.emplace_back(new MatchSubstringOptions("pattern", /*ignore_case=*/true)); + options.emplace_back(new SplitOptions()); + options.emplace_back(new SplitOptions(/*max_splits=*/2, /*reverse=*/true)); + options.emplace_back(new SplitPatternOptions("pattern")); + options.emplace_back( + new SplitPatternOptions("pattern", /*max_splits=*/2, /*reverse=*/true)); + options.emplace_back(new ReplaceSubstringOptions("pattern", "replacement")); + options.emplace_back( + new ReplaceSubstringOptions("pattern", "replacement", /*max_replacements=*/2)); + options.emplace_back(new ReplaceSliceOptions(0, 1, "foo")); + options.emplace_back(new ReplaceSliceOptions(1, 
-1, "bar")); + options.emplace_back(new ExtractRegexOptions("pattern")); + options.emplace_back(new ExtractRegexOptions("pattern2")); + options.emplace_back(new SetLookupOptions(ArrayFromJSON(int64(), "[1, 2, 3, 4]"))); + options.emplace_back(new SetLookupOptions(ArrayFromJSON(boolean(), "[true, false]"))); + options.emplace_back(new StrptimeOptions("%Y", TimeUnit::type::MILLI)); + options.emplace_back(new StrptimeOptions("%Y", TimeUnit::type::NANO)); + options.emplace_back(new PadOptions(5, " ")); + options.emplace_back(new PadOptions(10, "A")); + options.emplace_back(new TrimOptions(" ")); + options.emplace_back(new TrimOptions("abc")); + options.emplace_back(new SliceOptions(/*start=*/1)); + options.emplace_back(new SliceOptions(/*start=*/1, /*stop=*/-5, /*step=*/-2)); + options.emplace_back(new CompareOptions(CompareOperator::EQUAL)); + options.emplace_back(new CompareOptions(CompareOperator::LESS)); + // N.B. we never actually use field_nullability or field_metadata in Arrow + options.emplace_back(new ProjectOptions({"col1"}, {true}, {})); + options.emplace_back(new ProjectOptions({"col1"}, {false}, {})); + options.emplace_back( + new ProjectOptions({"col1"}, {false}, {key_value_metadata({{"key", "val"}})})); + options.emplace_back(new CastOptions(CastOptions::Safe(boolean()))); + options.emplace_back(new CastOptions(CastOptions::Unsafe(int64()))); + options.emplace_back(new FilterOptions()); + options.emplace_back( + new FilterOptions(FilterOptions::NullSelectionBehavior::EMIT_NULL)); + options.emplace_back(new TakeOptions()); + options.emplace_back(new TakeOptions(/*boundscheck=*/false)); + options.emplace_back(new DictionaryEncodeOptions()); + options.emplace_back( + new DictionaryEncodeOptions(DictionaryEncodeOptions::NullEncodingBehavior::ENCODE)); + options.emplace_back(new ArraySortOptions()); + options.emplace_back(new ArraySortOptions(SortOrder::Descending)); + options.emplace_back(new SortOptions()); + options.emplace_back(new 
SortOptions({SortKey("key", SortOrder::Ascending)})); + options.emplace_back(new SortOptions( + {SortKey("key", SortOrder::Descending), SortKey("value", SortOrder::Descending)})); + options.emplace_back(new PartitionNthOptions(/*pivot=*/0)); + options.emplace_back(new PartitionNthOptions(/*pivot=*/42)); + + for (size_t i = 0; i < options.size(); i++) { + const size_t prev_i = i == 0 ? options.size() - 1 : i - 1; + const FunctionOptions& cur = *options[i]; + const FunctionOptions& prev = *options[prev_i]; + SCOPED_TRACE(cur.type_name()); + SCOPED_TRACE(cur.ToString()); + ASSERT_EQ(cur, cur); + ASSERT_NE(cur, prev); + ASSERT_NE(prev, cur); + ASSERT_NE("", cur.ToString()); + + ASSERT_OK_AND_ASSIGN(auto serialized, cur.Serialize()); + const auto* type_name = cur.type_name(); + ASSERT_OK_AND_ASSIGN( + auto deserialized, + FunctionOptions::Deserialize(std::string(type_name, std::strlen(type_name)), + *serialized)); + ASSERT_TRUE(cur.Equals(*deserialized)); + } +} + struct ExecBatch; TEST(Arity, Basics) { diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index f8d15952e73..c88c924817c 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -41,7 +41,7 @@ namespace arrow { namespace compute { -struct FunctionOptions; +class FunctionOptions; /// \brief Base class for opaque kernel-specific state. For example, if there /// is some kind of initialization required. 
diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 673802f99b0..8a0d9e62518 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -24,8 +24,10 @@ #include #include "arrow/compute/function.h" +#include "arrow/compute/function_internal.h" #include "arrow/compute/registry_internal.h" #include "arrow/status.h" +#include "arrow/util/logging.h" namespace arrow { namespace compute { @@ -57,6 +59,20 @@ class FunctionRegistry::FunctionRegistryImpl { return Status::OK(); } + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + std::lock_guard mutation_guard(lock_); + + const std::string name = options_type->type_name(); + auto it = name_to_options_type_.find(name); + if (it != name_to_options_type_.end() && !allow_overwrite) { + return Status::KeyError( + "Already have a function options type registered with name: ", name); + } + name_to_options_type_[name] = options_type; + return Status::OK(); + } + Result> GetFunction(const std::string& name) const { auto it = name_to_function_.find(name); if (it == name_to_function_.end()) { @@ -74,11 +90,21 @@ class FunctionRegistry::FunctionRegistryImpl { return results; } + Result GetFunctionOptionsType( + const std::string& name) const { + auto it = name_to_options_type_.find(name); + if (it == name_to_options_type_.end()) { + return Status::KeyError("No function options type registered with name: ", name); + } + return it->second; + } + int num_functions() const { return static_cast(name_to_function_.size()); } private: std::mutex lock_; std::unordered_map> name_to_function_; + std::unordered_map name_to_options_type_; }; std::unique_ptr FunctionRegistry::Make() { @@ -99,6 +125,11 @@ Status FunctionRegistry::AddAlias(const std::string& target_name, return impl_->AddAlias(target_name, source_name); } +Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool 
allow_overwrite) { + return impl_->AddFunctionOptionsType(options_type, allow_overwrite); +} + Result> FunctionRegistry::GetFunction( const std::string& name) const { return impl_->GetFunction(name); @@ -108,6 +139,11 @@ std::vector FunctionRegistry::GetFunctionNames() const { return impl_->GetFunctionNames(); } +Result FunctionRegistry::GetFunctionOptionsType( + const std::string& name) const { + return impl_->GetFunctionOptionsType(name); +} + int FunctionRegistry::num_functions() const { return impl_->num_functions(); } namespace internal { @@ -128,12 +164,16 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarIfElse(registry.get()); RegisterScalarTemporal(registry.get()); + RegisterScalarOptions(registry.get()); + // Vector functions RegisterVectorHash(registry.get()); RegisterVectorSelection(registry.get()); RegisterVectorNested(registry.get()); RegisterVectorSort(registry.get()); + RegisterVectorOptions(registry.get()); + // Aggregate functions RegisterScalarAggregateBasic(registry.get()); RegisterScalarAggregateMode(registry.get()); @@ -142,6 +182,8 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarAggregateVariance(registry.get()); RegisterHashAggregateBasic(registry.get()); + RegisterAggregateOptions(registry.get()); + return registry; } diff --git a/cpp/src/arrow/compute/registry.h b/cpp/src/arrow/compute/registry.h index b4456dc5b6b..e83036db6ac 100644 --- a/cpp/src/arrow/compute/registry.h +++ b/cpp/src/arrow/compute/registry.h @@ -32,6 +32,7 @@ namespace arrow { namespace compute { class Function; +class FunctionOptionsType; /// \brief A mutable central function registry for built-in functions as well /// as user-defined functions. Functions are implementations of @@ -58,6 +59,11 @@ class ARROW_EXPORT FunctionRegistry { /// function with the given name is not registered Status AddAlias(const std::string& target_name, const std::string& source_name); + /// \brief Add a new function options type to the registry. 
Returns Status::KeyError if + /// a function options type with the same name is already registered + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false); + /// \brief Retrieve a function by name from the registry Result> GetFunction(const std::string& name) const; @@ -65,6 +71,10 @@ class ARROW_EXPORT FunctionRegistry { /// displaying a manifest of available functions std::vector GetFunctionNames() const; + /// \brief Retrieve a function options type by name from the registry + Result GetFunctionOptionsType( + const std::string& name) const; + /// \brief The number of currently registered functions int num_functions() const; diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 68e0f2207f1..dd0271eb43d 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -37,12 +37,16 @@ void RegisterScalarFillNull(FunctionRegistry* registry); void RegisterScalarIfElse(FunctionRegistry* registry); void RegisterScalarTemporal(FunctionRegistry* registry); +void RegisterScalarOptions(FunctionRegistry* registry); + // Vector functions void RegisterVectorHash(FunctionRegistry* registry); void RegisterVectorSelection(FunctionRegistry* registry); void RegisterVectorNested(FunctionRegistry* registry); void RegisterVectorSort(FunctionRegistry* registry); +void RegisterVectorOptions(FunctionRegistry* registry); + // Aggregate functions void RegisterScalarAggregateBasic(FunctionRegistry* registry); void RegisterScalarAggregateMode(FunctionRegistry* registry); @@ -51,6 +55,8 @@ void RegisterScalarAggregateTDigest(FunctionRegistry* registry); void RegisterScalarAggregateVariance(FunctionRegistry* registry); void RegisterHashAggregateBasic(FunctionRegistry* registry); +void RegisterAggregateOptions(FunctionRegistry* registry); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/type_fwd.h 
b/cpp/src/arrow/compute/type_fwd.h index 5370837f1b9..8a0d6de7f25 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -25,9 +25,9 @@ struct ValueDescr; namespace compute { class Function; -struct FunctionOptions; +class FunctionOptions; -struct CastOptions; +class CastOptions; struct ExecBatch; class ExecContext; diff --git a/cpp/src/arrow/testing/generator.cc b/cpp/src/arrow/testing/generator.cc index 71fad394d00..33371d55c6d 100644 --- a/cpp/src/arrow/testing/generator.cc +++ b/cpp/src/arrow/testing/generator.cc @@ -95,88 +95,16 @@ std::shared_ptr ConstantArrayGenerator::String(int64_t size, return ConstantArray(size, value); } -struct ScalarVectorToArrayImpl { - template ::BuilderType, - typename ScalarType = typename TypeTraits::ScalarType> - Status UseBuilder(const AppendScalar& append) { - BuilderType builder(type_, default_memory_pool()); - for (const auto& s : scalars_) { - if (s->is_valid) { - RETURN_NOT_OK(append(internal::checked_cast(*s), &builder)); - } else { - RETURN_NOT_OK(builder.AppendNull()); - } - } - return builder.FinishInternal(&data_); - } - - struct AppendValue { - template - Status operator()(const ScalarType& s, BuilderType* builder) const { - return builder->Append(s.value); - } - }; - - struct AppendBuffer { - template - Status operator()(const ScalarType& s, BuilderType* builder) const { - const Buffer& buffer = *s.value; - return builder->Append(util::string_view{buffer}); - } - }; - - template - enable_if_primitive_ctype Visit(const T&) { - return UseBuilder(AppendValue{}); - } - - template - enable_if_has_string_view Visit(const T&) { - return UseBuilder(AppendBuffer{}); - } - - Status Visit(const StructType& type) { - data_ = ArrayData::Make(type_, static_cast(scalars_.size()), - {/*null_bitmap=*/nullptr}); - data_->child_data.resize(type_->num_fields()); - - ScalarVector field_scalars(scalars_.size()); - - for (int field_index = 0; field_index < type.num_fields(); ++field_index) { - for (size_t 
i = 0; i < scalars_.size(); ++i) { - field_scalars[i] = - internal::checked_cast(scalars_[i].get())->value[field_index]; - } - - ARROW_ASSIGN_OR_RAISE(data_->child_data[field_index], - ScalarVectorToArrayImpl{}.Convert(field_scalars)); - } - return Status::OK(); - } - - Status Visit(const DataType& type) { - return Status::NotImplemented("ScalarVectorToArray for type ", type); - } - - Result> Convert(const ScalarVector& scalars) && { - if (scalars.size() == 0) { - return Status::NotImplemented("ScalarVectorToArray with no scalars"); - } - scalars_ = std::move(scalars); - type_ = scalars_[0]->type; - RETURN_NOT_OK(VisitTypeInline(*type_, this)); - return std::move(data_); - } - - std::shared_ptr type_; - ScalarVector scalars_; - std::shared_ptr data_; -}; - Result> ScalarVectorToArray(const ScalarVector& scalars) { - ARROW_ASSIGN_OR_RAISE(auto data, ScalarVectorToArrayImpl{}.Convert(scalars)); - return MakeArray(std::move(data)); + if (scalars.empty()) { + return Status::NotImplemented("ScalarVectorToArray with no scalars"); + } + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(default_memory_pool(), scalars[0]->type, &builder)); + RETURN_NOT_OK(builder->AppendScalars(scalars)); + std::shared_ptr out; + RETURN_NOT_OK(builder->Finish(&out)); + return out; } } // namespace arrow diff --git a/cpp/src/arrow/util/reflection_internal.h b/cpp/src/arrow/util/reflection_internal.h index 522815dd2be..0440a2eb563 100644 --- a/cpp/src/arrow/util/reflection_internal.h +++ b/cpp/src/arrow/util/reflection_internal.h @@ -21,6 +21,7 @@ #include #include +#include "arrow/type_traits.h" #include "arrow/util/string_view.h" namespace arrow { @@ -112,5 +113,21 @@ PropertyTuple MakeProperties(Properties... 
props) { return {std::make_tuple(props...)}; } +template +struct EnumTraits {}; + +template +struct BasicEnumTraits { + using CType = typename std::underlying_type::type; + using Type = typename CTypeTraits::ArrowType; + static std::array values() { return {Values...}; } +}; + +template +struct has_enum_traits : std::false_type {}; + +template +struct has_enum_traits::Type>> : std::true_type {}; + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/reflection_test.cc b/cpp/src/arrow/util/reflection_test.cc index 4ffcf679ecc..fb3d3b8fb02 100644 --- a/cpp/src/arrow/util/reflection_test.cc +++ b/cpp/src/arrow/util/reflection_test.cc @@ -193,5 +193,32 @@ TEST(Reflection, FromStringToDataMembers) { EXPECT_EQ(PersonFromString("Person{age: 19, name: Genos}"), util::nullopt); } +enum class PersonType : int8_t { + EMPLOYEE, + CONTRACTOR, +}; + +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "PersonType"; } + static std::string value_name(PersonType value) { + switch (value) { + case PersonType::EMPLOYEE: + return "EMPLOYEE"; + case PersonType::CONTRACTOR: + return "CONTRACTOR"; + } + return ""; + } +}; + +TEST(Reflection, EnumTraits) { + static_assert(!has_enum_traits::value, ""); + static_assert(has_enum_traits::value, ""); + static_assert(std::is_same::CType, int8_t>::value, ""); + static_assert(std::is_same::Type, Int8Type>::value, ""); +} + } // namespace internal } // namespace arrow diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index e187ed75b69..8358271efa7 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -23,5 +23,8 @@ from pyarrow.includes.libarrow cimport * cdef class FunctionOptions(_Weakrefable): + cdef: + unique_ptr[CFunctionOptions] wrapped cdef const CFunctionOptions* get_options(self) except NULL + cdef void init(self, unique_ptr[CFunctionOptions] options) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 
ae08a5596f3..c8393103dc5 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -526,9 +526,70 @@ def call_function(name, args, options=None, memory_pool=None): cdef class FunctionOptions(_Weakrefable): + __slots__ = () # avoid mistakingly creating attributes cdef const CFunctionOptions* get_options(self) except NULL: - raise NotImplementedError("Unimplemented base options") + return self.wrapped.get() + + cdef void init(self, unique_ptr[CFunctionOptions] options): + self.wrapped = move(options) + + def serialize(self): + cdef: + CResult[shared_ptr[CBuffer]] res = self.get_options().Serialize() + shared_ptr[CBuffer] c_buf = GetResultValue(res) + return pyarrow_wrap_buffer(c_buf) + + @staticmethod + def deserialize(buf): + cdef: + shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) + CResult[unique_ptr[CFunctionOptions]] maybe_options = \ + DeserializeFunctionOptions(deref(c_buf)) + unique_ptr[CFunctionOptions] c_options + c_options = move(GetResultValue(move(maybe_options))) + type_name = frombytes(c_options.get().options_type().type_name()) + mapping = { + "array_sort": ArraySortOptions, + "cast": CastOptions, + "dictionary_encode": DictionaryEncodeOptions, + "element_wise_aggregate": ElementWiseAggregateOptions, + "extract_regex": ExtractRegexOptions, + "filter": FilterOptions, + "index": IndexOptions, + "join": JoinOptions, + "match_substring": MatchSubstringOptions, + "mode": ModeOptions, + "pad": PadOptions, + "partition_nth": PartitionNthOptions, + "project": ProjectOptions, + "quantile": QuantileOptions, + "replace_slice": ReplaceSliceOptions, + "replace_substring": ReplaceSubstringOptions, + "set_lookup": SetLookupOptions, + "scalar_aggregate": ScalarAggregateOptions, + "slice": SliceOptions, + "sort": SortOptions, + "split": SplitOptions, + "split_pattern": SplitPatternOptions, + "strptime": StrptimeOptions, + "t_digest": TDigestOptions, + "take": TakeOptions, + "trim": TrimOptions, + "variance": VarianceOptions, + } + if 
type_name not in mapping: + raise ValueError(f"Cannot deserialize '{type_name}'") + klass = mapping[type_name] + options = klass.__new__(klass) + ( options).init(move(c_options)) + return options + + def __repr__(self): + return frombytes(self.get_options().ToString()) + + def __eq__(self, FunctionOptions other): + return self.get_options().Equals(deref(other.get_options())) # NOTE: @@ -541,17 +602,16 @@ cdef class FunctionOptions(_Weakrefable): cdef class _CastOptions(FunctionOptions): cdef: - unique_ptr[CCastOptions] options + CCastOptions* options - __slots__ = () # avoid mistakingly creating attributes - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.options.get() + cdef void init(self, unique_ptr[CFunctionOptions] options): + FunctionOptions.init(self, move(options)) + self.options = self.wrapped.get() def _set_options(self, DataType target_type, allow_int_overflow, allow_time_truncate, allow_time_overflow, allow_float_truncate, allow_invalid_utf8): - self.options.reset(new CCastOptions()) + self.init(unique_ptr[CFunctionOptions](new CCastOptions())) self._set_type(target_type) if allow_int_overflow is not None: self.allow_int_overflow = allow_int_overflow @@ -571,10 +631,12 @@ cdef class _CastOptions(FunctionOptions): ) def _set_safe(self): - self.options.reset(new CCastOptions(CCastOptions.Safe())) + self.init(unique_ptr[CFunctionOptions]( + new CCastOptions(CCastOptions.Safe()))) def _set_unsafe(self): - self.options.reset(new CCastOptions(CCastOptions.Unsafe())) + self.init(unique_ptr[CFunctionOptions]( + new CCastOptions(CCastOptions.Unsafe()))) def is_safe(self): return not ( @@ -651,15 +713,8 @@ class CastOptions(_CastOptions): cdef class _ElementWiseAggregateOptions(FunctionOptions): - cdef: - unique_ptr[CElementWiseAggregateOptions] element_wise_aggregate_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.element_wise_aggregate_options.get() - def _set_options(self, bint skip_nulls): - 
self.element_wise_aggregate_options.reset( - new CElementWiseAggregateOptions(skip_nulls)) + self.wrapped.reset(new CElementWiseAggregateOptions(skip_nulls)) class ElementWiseAggregateOptions(_ElementWiseAggregateOptions): @@ -668,12 +723,6 @@ class ElementWiseAggregateOptions(_ElementWiseAggregateOptions): cdef class _JoinOptions(FunctionOptions): - cdef: - unique_ptr[CJoinOptions] join_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.join_options.get() - def _set_options(self, null_handling, null_replacement): cdef: CJoinNullHandlingBehavior c_null_handling = \ @@ -689,7 +738,7 @@ cdef class _JoinOptions(FunctionOptions): raise ValueError( '"{}" is not a valid null_handling' .format(null_handling)) - self.join_options.reset( + self.wrapped.reset( new CJoinOptions(c_null_handling, c_null_replacement)) @@ -699,14 +748,8 @@ class JoinOptions(_JoinOptions): cdef class _MatchSubstringOptions(FunctionOptions): - cdef: - unique_ptr[CMatchSubstringOptions] match_substring_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.match_substring_options.get() - def _set_options(self, pattern, bint ignore_case): - self.match_substring_options.reset( + self.wrapped.reset( new CMatchSubstringOptions(tobytes(pattern), ignore_case)) @@ -716,15 +759,8 @@ class MatchSubstringOptions(_MatchSubstringOptions): cdef class _PadOptions(FunctionOptions): - cdef: - unique_ptr[CPadOptions] pad_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.pad_options.get() - def _set_options(self, width, padding): - self.pad_options.reset( - new CPadOptions(width, tobytes(padding))) + self.wrapped.reset(new CPadOptions(width, tobytes(padding))) class PadOptions(_PadOptions): @@ -733,15 +769,8 @@ class PadOptions(_PadOptions): cdef class _TrimOptions(FunctionOptions): - cdef: - unique_ptr[CTrimOptions] trim_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return 
self.trim_options.get() - def _set_options(self, characters): - self.trim_options.reset( - new CTrimOptions(tobytes(characters))) + self.wrapped.reset(new CTrimOptions(tobytes(characters))) class TrimOptions(_TrimOptions): @@ -750,14 +779,8 @@ class TrimOptions(_TrimOptions): cdef class _ReplaceSliceOptions(FunctionOptions): - cdef: - unique_ptr[CReplaceSliceOptions] replace_slice_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.replace_slice_options.get() - def _set_options(self, start, stop, replacement): - self.replace_slice_options.reset( + self.wrapped.reset( new CReplaceSliceOptions(start, stop, tobytes(replacement)) ) @@ -768,14 +791,8 @@ class ReplaceSliceOptions(_ReplaceSliceOptions): cdef class _ReplaceSubstringOptions(FunctionOptions): - cdef: - unique_ptr[CReplaceSubstringOptions] replace_substring_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.replace_substring_options.get() - def _set_options(self, pattern, replacement, max_replacements): - self.replace_substring_options.reset( + self.wrapped.reset( new CReplaceSubstringOptions(tobytes(pattern), tobytes(replacement), max_replacements) @@ -788,14 +805,8 @@ class ReplaceSubstringOptions(_ReplaceSubstringOptions): cdef class _ExtractRegexOptions(FunctionOptions): - cdef: - unique_ptr[CExtractRegexOptions] extract_regex_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.extract_regex_options.get() - def _set_options(self, pattern): - self.extract_regex_options.reset( + self.wrapped.reset( new CExtractRegexOptions(tobytes(pattern))) @@ -805,15 +816,8 @@ class ExtractRegexOptions(_ExtractRegexOptions): cdef class _SliceOptions(FunctionOptions): - cdef: - unique_ptr[CSliceOptions] slice_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.slice_options.get() - def _set_options(self, start, stop, step): - self.slice_options.reset( - new CSliceOptions(start, stop, 
step)) + self.wrapped.reset(new CSliceOptions(start, stop, step)) class SliceOptions(_SliceOptions): @@ -822,18 +826,12 @@ class SliceOptions(_SliceOptions): cdef class _FilterOptions(FunctionOptions): - cdef: - unique_ptr[CFilterOptions] filter_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.filter_options.get() - def _set_options(self, null_selection_behavior): if null_selection_behavior == 'drop': - self.filter_options.reset( + self.wrapped.reset( new CFilterOptions(CFilterNullSelectionBehavior_DROP)) elif null_selection_behavior == 'emit_null': - self.filter_options.reset( + self.wrapped.reset( new CFilterOptions(CFilterNullSelectionBehavior_EMIT_NULL)) else: raise ValueError( @@ -847,19 +845,13 @@ class FilterOptions(_FilterOptions): cdef class _DictionaryEncodeOptions(FunctionOptions): - cdef: - unique_ptr[CDictionaryEncodeOptions] dictionary_encode_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.dictionary_encode_options.get() - def _set_options(self, null_encoding_behavior): if null_encoding_behavior == 'encode': - self.dictionary_encode_options.reset( + self.wrapped.reset( new CDictionaryEncodeOptions( CDictionaryEncodeNullEncodingBehavior_ENCODE)) elif null_encoding_behavior == 'mask': - self.dictionary_encode_options.reset( + self.wrapped.reset( new CDictionaryEncodeOptions( CDictionaryEncodeNullEncodingBehavior_MASK)) else: @@ -873,14 +865,8 @@ class DictionaryEncodeOptions(_DictionaryEncodeOptions): cdef class _TakeOptions(FunctionOptions): - cdef: - unique_ptr[CTakeOptions] take_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.take_options.get() - def _set_options(self, boundscheck): - self.take_options.reset(new CTakeOptions(boundscheck)) + self.wrapped.reset(new CTakeOptions(boundscheck)) class TakeOptions(_TakeOptions): @@ -889,14 +875,8 @@ class TakeOptions(_TakeOptions): cdef class _PartitionNthOptions(FunctionOptions): - cdef: - 
unique_ptr[CPartitionNthOptions] partition_nth_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.partition_nth_options.get() - def _set_options(self, int64_t pivot): - self.partition_nth_options.reset(new CPartitionNthOptions(pivot)) + self.wrapped.reset(new CPartitionNthOptions(pivot)) class PartitionNthOptions(_PartitionNthOptions): @@ -905,18 +885,12 @@ class PartitionNthOptions(_PartitionNthOptions): cdef class _ProjectOptions(FunctionOptions): - cdef: - unique_ptr[CProjectOptions] project_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.project_options.get() - def _set_options(self, field_names): cdef: vector[c_string] c_field_names for n in field_names: c_field_names.push_back(tobytes(n)) - self.project_options.reset(new CProjectOptions(field_names)) + self.wrapped.reset(new CProjectOptions(field_names)) class ProjectOptions(_ProjectOptions): @@ -925,14 +899,8 @@ class ProjectOptions(_ProjectOptions): cdef class _ScalarAggregateOptions(FunctionOptions): - cdef: - unique_ptr[CScalarAggregateOptions] scalar_aggregate_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.scalar_aggregate_options.get() - def _set_options(self, skip_nulls, min_count): - self.scalar_aggregate_options.reset( + self.wrapped.reset( new CScalarAggregateOptions(skip_nulls, min_count)) @@ -942,15 +910,8 @@ class ScalarAggregateOptions(_ScalarAggregateOptions): cdef class _IndexOptions(FunctionOptions): - cdef: - unique_ptr[CIndexOptions] index_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.index_options.get() - def _set_options(self, Scalar scalar): - self.index_options.reset( - new CIndexOptions(pyarrow_unwrap_scalar(scalar))) + self.wrapped.reset(new CIndexOptions(pyarrow_unwrap_scalar(scalar))) class IndexOptions(_IndexOptions): @@ -968,14 +929,8 @@ class IndexOptions(_IndexOptions): cdef class _ModeOptions(FunctionOptions): - cdef: - 
unique_ptr[CModeOptions] mode_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.mode_options.get() - def _set_options(self, n): - self.mode_options.reset(new CModeOptions(n)) + self.wrapped.reset(new CModeOptions(n)) class ModeOptions(_ModeOptions): @@ -985,12 +940,8 @@ class ModeOptions(_ModeOptions): cdef class _SetLookupOptions(FunctionOptions): cdef: - unique_ptr[CSetLookupOptions] set_lookup_options unique_ptr[CDatum] valset - cdef const CFunctionOptions* get_options(self) except NULL: - return self.set_lookup_options.get() - def _set_options(self, value_set, c_bool skip_nulls): if isinstance(value_set, Array): self.valset.reset(new CDatum(( value_set).sp_array)) @@ -1003,9 +954,8 @@ cdef class _SetLookupOptions(FunctionOptions): else: raise ValueError('"{}" is not a valid value_set'.format(value_set)) - self.set_lookup_options.reset( - new CSetLookupOptions(deref(self.valset), skip_nulls) - ) + self.wrapped.reset( + new CSetLookupOptions(deref(self.valset), skip_nulls)) class SetLookupOptions(_SetLookupOptions): @@ -1014,27 +964,20 @@ class SetLookupOptions(_SetLookupOptions): cdef class _StrptimeOptions(FunctionOptions): - cdef: - unique_ptr[CStrptimeOptions] strptime_options - TimeUnit time_unit - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.strptime_options.get() - def _set_options(self, format, unit): if unit == 's': - self.time_unit = TimeUnit_SECOND + time_unit = TimeUnit_SECOND elif unit == 'ms': - self.time_unit = TimeUnit_MILLI + time_unit = TimeUnit_MILLI elif unit == 'us': - self.time_unit = TimeUnit_MICRO + time_unit = TimeUnit_MICRO elif unit == 'ns': - self.time_unit = TimeUnit_NANO + time_unit = TimeUnit_NANO else: raise ValueError('"{}" is not a valid time unit'.format(unit)) - self.strptime_options.reset( - new CStrptimeOptions(tobytes(format), self.time_unit) + self.wrapped.reset( + new CStrptimeOptions(tobytes(format), time_unit) ) @@ -1044,14 +987,8 @@ class 
StrptimeOptions(_StrptimeOptions): cdef class _VarianceOptions(FunctionOptions): - cdef: - unique_ptr[CVarianceOptions] variance_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.variance_options.get() - def _set_options(self, ddof): - self.variance_options.reset(new CVarianceOptions(ddof)) + self.wrapped.reset(new CVarianceOptions(ddof)) class VarianceOptions(_VarianceOptions): @@ -1060,14 +997,8 @@ class VarianceOptions(_VarianceOptions): cdef class _SplitOptions(FunctionOptions): - cdef: - unique_ptr[CSplitOptions] split_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.split_options.get() - def _set_options(self, max_splits, reverse): - self.split_options.reset( + self.wrapped.reset( new CSplitOptions(max_splits, reverse)) @@ -1077,14 +1008,8 @@ class SplitOptions(_SplitOptions): cdef class _SplitPatternOptions(FunctionOptions): - cdef: - unique_ptr[CSplitPatternOptions] split_pattern_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.split_pattern_options.get() - def _set_options(self, pattern, max_splits, reverse): - self.split_pattern_options.reset( + self.wrapped.reset( new CSplitPatternOptions(tobytes(pattern), max_splits, reverse)) @@ -1094,19 +1019,11 @@ class SplitPatternOptions(_SplitPatternOptions): cdef class _ArraySortOptions(FunctionOptions): - cdef: - unique_ptr[CArraySortOptions] array_sort_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.array_sort_options.get() - def _set_options(self, order): if order == "ascending": - self.array_sort_options.reset( - new CArraySortOptions(CSortOrder_Ascending)) + self.wrapped.reset(new CArraySortOptions(CSortOrder_Ascending)) elif order == "descending": - self.array_sort_options.reset( - new CArraySortOptions(CSortOrder_Descending)) + self.wrapped.reset(new CArraySortOptions(CSortOrder_Descending)) else: raise ValueError( "{!r} is not a valid order".format(order) @@ 
-1119,12 +1036,6 @@ class ArraySortOptions(_ArraySortOptions): cdef class _SortOptions(FunctionOptions): - cdef: - unique_ptr[CSortOptions] sort_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.sort_options.get() - def _set_options(self, sort_keys): cdef: vector[CSortKey] c_sort_keys @@ -1143,7 +1054,7 @@ cdef class _SortOptions(FunctionOptions): c_name = tobytes(name) c_sort_keys.push_back(CSortKey(c_name, c_order)) - self.sort_options.reset(new CSortOptions(c_sort_keys)) + self.wrapped.reset(new CSortOptions(c_sort_keys)) class SortOptions(_SortOptions): @@ -1154,12 +1065,6 @@ class SortOptions(_SortOptions): cdef class _QuantileOptions(FunctionOptions): - cdef: - unique_ptr[CQuantileOptions] quantile_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.quantile_options.get() - def _set_options(self, quantiles, interp): interp_dict = { 'linear': CQuantileInterp_LINEAR, @@ -1172,7 +1077,7 @@ cdef class _QuantileOptions(FunctionOptions): raise ValueError( '{!r} is not a valid interpolation' .format(interp)) - self.quantile_options.reset( + self.wrapped.reset( new CQuantileOptions(quantiles, interp_dict[interp])) @@ -1184,14 +1089,8 @@ class QuantileOptions(_QuantileOptions): cdef class _TDigestOptions(FunctionOptions): - cdef: - unique_ptr[CTDigestOptions] tdigest_options - - cdef const CFunctionOptions* get_options(self) except NULL: - return self.tdigest_options.get() - def _set_options(self, quantiles, delta, buffer_size): - self.tdigest_options.reset( + self.wrapped.reset( new CTDigestOptions(quantiles, delta, buffer_size)) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 653a2b83781..07983b79f40 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1751,8 +1751,18 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: vector[c_string] arg_names c_string options_class + cdef cppclass 
CFunctionOptionsType" arrow::compute::FunctionOptionsType": + const char* type_name() const + cdef cppclass CFunctionOptions" arrow::compute::FunctionOptions": - pass + const CFunctionOptionsType* options_type() const + c_bool Equals(const CFunctionOptions& other) + c_string ToString() + CResult[shared_ptr[CBuffer]] Serialize() const + + @staticmethod + CResult[unique_ptr[CFunctionOptions]] Deserialize( + const c_string& type_name, const CBuffer&) cdef cppclass CFunction" arrow::compute::Function": const c_string& name() const @@ -1843,9 +1853,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool reverse cdef cppclass CSplitPatternOptions \ - "arrow::compute::SplitPatternOptions"(CSplitOptions): + "arrow::compute::SplitPatternOptions"(CFunctionOptions): CSplitPatternOptions(c_string pattern, int64_t max_splits, c_bool reverse) + int64_t max_splits + c_bool reverse c_string pattern cdef cppclass CReplaceSliceOptions \ @@ -2027,6 +2039,25 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool skip_nulls +cdef extern from * namespace "arrow::compute": + # inlined from compute/function_internal.h to avoid exposing + # implementation details + """ + #include "arrow/compute/function.h" + namespace arrow { + namespace compute { + namespace internal { + Result<std::unique_ptr<FunctionOptions>> DeserializeFunctionOptions( + const Buffer& buffer); + } // namespace internal + } // namespace compute + } // namespace arrow + """ + CResult[unique_ptr[CFunctionOptions]] DeserializeFunctionOptions\ + " arrow::compute::internal::DeserializeFunctionOptions"( + const CBuffer& buffer) + + cdef extern from "arrow/python/api.h" namespace "arrow::py": # Requires GIL CResult[shared_ptr[CDataType]] InferArrowType( diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 3a10da0ca2b..264da5805e1 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -111,6 +111,41 @@ def
test_exported_option_classes(): param.VAR_KEYWORD) +def test_option_class_equality(): + options = [ + pc.CastOptions.safe(pa.int8()), + pc.ExtractRegexOptions("pattern"), + pc.IndexOptions(pa.scalar(1)), + pc.MatchSubstringOptions("pattern"), + pc.PadOptions(5, " "), + pc.PartitionNthOptions(1), + pc.ProjectOptions([b"field", b"names"]), + pc.ReplaceSliceOptions(start=0, stop=1, replacement="a"), + pc.ReplaceSubstringOptions("a", "b"), + pc.SetLookupOptions(value_set=pa.array([1])), + pc.SliceOptions(start=0, stop=1, step=1), + pc.SplitPatternOptions(pattern="pattern"), + pc.StrptimeOptions("%Y", "s"), + pc.TrimOptions(" "), + ] + classes = {type(option) for option in options} + for cls in exported_option_classes: + if cls not in classes: + try: + options.append(cls()) + except TypeError: + pytest.fail(f"Options class is not tested: {cls}") + for option in options: + assert option == option + assert repr(option) + buf = option.serialize() + deserialized = pc.FunctionOptions.deserialize(buf) + assert option == deserialized + assert repr(option) == repr(deserialized) + for option1, option2 in zip(options, options[1:]): + assert option1 != option2 + + def test_list_functions(): assert len(pc.list_functions()) > 10 assert "add" in pc.list_functions()