From 50eb6d7f7f3caa45fb5282bc64ce20e0454d2585 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 1 Jul 2021 08:47:20 -0400 Subject: [PATCH] ARROW-13235: [C++][Python] Simplify mapping of function options --- cpp/src/arrow/compute/api_aggregate.h | 12 ++++----- cpp/src/arrow/compute/api_scalar.cc | 7 +++++ cpp/src/arrow/compute/api_scalar.h | 36 +++++++++++++------------- cpp/src/arrow/compute/api_vector.h | 12 ++++----- cpp/src/arrow/compute/cast.h | 2 +- cpp/src/arrow/compute/function_test.cc | 2 ++ python/pyarrow/_compute.pyx | 34 +++--------------------- 7 files changed, 43 insertions(+), 62 deletions(-) diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 9be0b406aa4..7b6e2ef96de 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -46,7 +46,7 @@ class ExecContext; class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { public: explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1); - constexpr static char const kTypeName[] = "scalar_aggregate"; + constexpr static char const kTypeName[] = "ScalarAggregateOptions"; static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; } bool skip_nulls; @@ -60,7 +60,7 @@ class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { class ARROW_EXPORT ModeOptions : public FunctionOptions { public: explicit ModeOptions(int64_t n = 1); - constexpr static char const kTypeName[] = "mode"; + constexpr static char const kTypeName[] = "ModeOptions"; static ModeOptions Defaults() { return ModeOptions{}; } int64_t n = 1; @@ -73,7 +73,7 @@ class ARROW_EXPORT ModeOptions : public FunctionOptions { class ARROW_EXPORT VarianceOptions : public FunctionOptions { public: explicit VarianceOptions(int ddof = 0); - constexpr static char const kTypeName[] = "variance"; + constexpr static char const kTypeName[] = "VarianceOptions"; static VarianceOptions Defaults() { return VarianceOptions{}; } int ddof = 0; @@ -98,7 +98,7 @@ class ARROW_EXPORT QuantileOptions : public FunctionOptions { explicit QuantileOptions(std::vector q, enum Interpolation interpolation = LINEAR); - constexpr static char const kTypeName[] = "quantile"; + constexpr static char const kTypeName[] = "QuantileOptions"; static QuantileOptions Defaults() { return QuantileOptions{}; } /// quantile must be between 0 and 1 inclusive @@ -115,7 +115,7 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions { uint32_t buffer_size = 500); explicit TDigestOptions(std::vector q, uint32_t delta = 100, uint32_t buffer_size = 500); - constexpr static char const kTypeName[] = "t_digest"; + constexpr static char const kTypeName[] = "TDigestOptions"; static TDigestOptions Defaults() { return TDigestOptions{}; } /// quantile must be between 0 and 1 inclusive @@ -132,7 +132,7 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions { explicit IndexOptions(std::shared_ptr value); // Default constructor for serialization IndexOptions(); - constexpr static char const kTypeName[] = "index"; + constexpr static char const kTypeName[] = "IndexOptions"; std::shared_ptr value; }; diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 11b5b45b7a0..2021c8a30c6 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -110,6 +110,8 @@ using ::arrow::internal::checked_cast; namespace internal { namespace { using ::arrow::internal::DataMember; +static auto kArithmeticOptionsType = GetFunctionOptionsType( + DataMember("check_overflow", &ArithmeticOptions::check_overflow)); static auto kElementWiseAggregateOptionsType = GetFunctionOptionsType( DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls)); @@ -159,6 +161,10 @@ static auto kProjectOptionsType = GetFunctionOptionsType( } // namespace } // namespace internal +ArithmeticOptions::ArithmeticOptions(bool check_overflow) + : FunctionOptions(internal::kArithmeticOptionsType), check_overflow(check_overflow) {} +constexpr char ArithmeticOptions::kTypeName[]; + ElementWiseAggregateOptions::ElementWiseAggregateOptions(bool skip_nulls) : FunctionOptions(internal::kElementWiseAggregateOptionsType), skip_nulls(skip_nulls) {} @@ -274,6 +280,7 @@ constexpr char ProjectOptions::kTypeName[]; namespace internal { void RegisterScalarOptions(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunctionOptionsType(kArithmeticOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType)); diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index bacb287d6bc..89b4faca940 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -37,17 +37,17 @@ namespace compute { /// /// @{ -struct ARROW_EXPORT ArithmeticOptions { +class ARROW_EXPORT ArithmeticOptions : public FunctionOptions { public: - explicit ArithmeticOptions(bool check_overflow = false) - : check_overflow(check_overflow) {} + explicit ArithmeticOptions(bool check_overflow = false); + constexpr static char const kTypeName[] = "ArithmeticOptions"; bool check_overflow; }; class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { public: explicit ElementWiseAggregateOptions(bool skip_nulls = true); - constexpr static char const kTypeName[] = "element_wise_aggregate"; + constexpr static char const kTypeName[] = "ElementWiseAggregateOptions"; static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; } bool skip_nulls; @@ -67,7 +67,7 @@ class ARROW_EXPORT JoinOptions : public FunctionOptions { }; explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL, std::string null_replacement = ""); - constexpr static char const kTypeName[] = "join"; + constexpr static char const kTypeName[] = "JoinOptions"; static JoinOptions Defaults() { return JoinOptions(); } NullHandlingBehavior null_handling; std::string null_replacement; @@ -77,7 +77,7 @@ class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { public: explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false); MatchSubstringOptions(); - constexpr static char const kTypeName[] = "match_substring"; + constexpr static char const kTypeName[] = "MatchSubstringOptions"; /// The exact substring (or regex, depending on kernel) to look for inside input values. std::string pattern; @@ -88,7 +88,7 @@ class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { class ARROW_EXPORT SplitOptions : public FunctionOptions { public: explicit SplitOptions(int64_t max_splits = -1, bool reverse = false); - constexpr static char const kTypeName[] = "split"; + constexpr static char const kTypeName[] = "SplitOptions"; /// Maximum number of splits allowed, or unlimited when -1 int64_t max_splits; @@ -101,7 +101,7 @@ class ARROW_EXPORT SplitPatternOptions : public FunctionOptions { explicit SplitPatternOptions(std::string pattern, int64_t max_splits = -1, bool reverse = false); SplitPatternOptions(); - constexpr static char const kTypeName[] = "split_pattern"; + constexpr static char const kTypeName[] = "SplitPatternOptions"; /// The exact substring to split on. std::string pattern; @@ -115,7 +115,7 @@ class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { public: explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement); ReplaceSliceOptions(); - constexpr static char const kTypeName[] = "replace_slice"; + constexpr static char const kTypeName[] = "ReplaceSliceOptions"; /// Index to start slicing at int64_t start; @@ -130,7 +130,7 @@ class ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, int64_t max_replacements = -1); ReplaceSubstringOptions(); - constexpr static char const kTypeName[] = "replace_substring"; + constexpr static char const kTypeName[] = "ReplaceSubstringOptions"; /// Pattern to match, literal, or regular expression depending on which kernel is used std::string pattern; @@ -144,7 +144,7 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { public: explicit ExtractRegexOptions(std::string pattern); ExtractRegexOptions(); - constexpr static char const kTypeName[] = "extract_regex"; + constexpr static char const kTypeName[] = "ExtractRegexOptions"; /// Regular expression with named capture fields std::string pattern; @@ -155,7 +155,7 @@ class ARROW_EXPORT SetLookupOptions : public FunctionOptions { public: explicit SetLookupOptions(Datum value_set, bool skip_nulls = false); SetLookupOptions(); - constexpr static char const kTypeName[] = "set_lookup"; + constexpr static char const kTypeName[] = "SetLookupOptions"; /// The set of values to look up input values into. Datum value_set; @@ -172,7 +172,7 @@ class ARROW_EXPORT StrptimeOptions : public FunctionOptions { public: explicit StrptimeOptions(std::string format, TimeUnit::type unit); StrptimeOptions(); - constexpr static char const kTypeName[] = "strptime"; + constexpr static char const kTypeName[] = "StrptimeOptions"; std::string format; TimeUnit::type unit; @@ -182,7 +182,7 @@ class ARROW_EXPORT PadOptions : public FunctionOptions { public: explicit PadOptions(int64_t width, std::string padding = " "); PadOptions(); - constexpr static char const kTypeName[] = "pad"; + constexpr static char const kTypeName[] = "PadOptions"; /// The desired string length. int64_t width; @@ -194,7 +194,7 @@ class ARROW_EXPORT TrimOptions : public FunctionOptions { public: explicit TrimOptions(std::string characters); TrimOptions(); - constexpr static char const kTypeName[] = "trim"; + constexpr static char const kTypeName[] = "TrimOptions"; /// The individual characters that can be trimmed from the string. std::string characters; @@ -205,7 +205,7 @@ class ARROW_EXPORT SliceOptions : public FunctionOptions { explicit SliceOptions(int64_t start, int64_t stop = std::numeric_limits::max(), int64_t step = 1); SliceOptions(); - constexpr static char const kTypeName[] = "slice"; + constexpr static char const kTypeName[] = "SliceOptions"; int64_t start, stop, step; }; @@ -222,7 +222,7 @@ class ARROW_EXPORT CompareOptions : public FunctionOptions { public: explicit CompareOptions(CompareOperator op); CompareOptions(); - constexpr static char const kTypeName[] = "compare"; + constexpr static char const kTypeName[] = "CompareOptions"; enum CompareOperator op; }; @@ -232,7 +232,7 @@ class ARROW_EXPORT ProjectOptions : public FunctionOptions { std::vector> m); explicit ProjectOptions(std::vector n); ProjectOptions(); - constexpr static char const kTypeName[] = "project"; + constexpr static char const kTypeName[] = "ProjectOptions"; /// Names for wrapped columns std::vector field_names; diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 2282b0098f9..6021492320e 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -43,7 +43,7 @@ class ARROW_EXPORT FilterOptions : public FunctionOptions { }; explicit FilterOptions(NullSelectionBehavior null_selection = DROP); - constexpr static char const kTypeName[] = "filter"; + constexpr static char const kTypeName[] = "FilterOptions"; static FilterOptions Defaults() { return FilterOptions(); } NullSelectionBehavior null_selection_behavior = DROP; @@ -52,7 +52,7 @@ class ARROW_EXPORT FilterOptions : public FunctionOptions { class ARROW_EXPORT TakeOptions : public FunctionOptions { public: explicit TakeOptions(bool boundscheck = true); - constexpr static char const kTypeName[] = "take"; + constexpr static char const kTypeName[] = "TakeOptions"; static TakeOptions BoundsCheck() { return TakeOptions(true); } static TakeOptions NoBoundsCheck() { return TakeOptions(false); } static TakeOptions Defaults() { return BoundsCheck(); } @@ -72,7 +72,7 @@ class ARROW_EXPORT DictionaryEncodeOptions : public FunctionOptions { }; explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK); - constexpr static char const kTypeName[] = "dictionary_encode"; + constexpr static char const kTypeName[] = "DictionaryEncodeOptions"; static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); } NullEncodingBehavior null_encoding_behavior = MASK; @@ -104,7 +104,7 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { class ARROW_EXPORT ArraySortOptions : public FunctionOptions { public: explicit ArraySortOptions(SortOrder order = SortOrder::Ascending); - constexpr static char const kTypeName[] = "array_sort"; + constexpr static char const kTypeName[] = "ArraySortOptions"; static ArraySortOptions Defaults() { return ArraySortOptions{}; } SortOrder order; @@ -113,7 +113,7 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { class ARROW_EXPORT SortOptions : public FunctionOptions { public: explicit SortOptions(std::vector sort_keys = {}); - constexpr static char const kTypeName[] = "sort"; + constexpr static char const kTypeName[] = "SortOptions"; static SortOptions Defaults() { return SortOptions{}; } std::vector sort_keys; @@ -124,7 +124,7 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { public: explicit PartitionNthOptions(int64_t pivot); PartitionNthOptions() : PartitionNthOptions(0) {} - constexpr static char const kTypeName[] = "partition_nth"; + constexpr static char const kTypeName[] = "PartitionNthOptions"; /// The index into the equivalent sorted array of the partition pivot element. int64_t pivot; diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 8abd2a71bca..131f57f892f 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -45,7 +45,7 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { public: explicit CastOptions(bool safe = true); - constexpr static char const kTypeName[] = "cast"; + constexpr static char const kTypeName[] = "CastOptions"; static CastOptions Safe(std::shared_ptr to_type = NULLPTR) { CastOptions safe(true); safe.to_type = std::move(to_type); diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index 4c42ce39600..bbe514af09a 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -53,6 +53,8 @@ TEST(FunctionOptions, Equality) { options.emplace_back(new IndexOptions(ScalarFromJSON(int64(), "16"))); options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "true"))); options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "null"))); + options.emplace_back(new ArithmeticOptions()); + options.emplace_back(new ArithmeticOptions(/*check_overflow=*/true)); options.emplace_back(new ElementWiseAggregateOptions()); options.emplace_back(new ElementWiseAggregateOptions(/*skip_nulls=*/false)); options.emplace_back(new JoinOptions()); diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c8393103dc5..63e6fffc782 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -549,38 +549,10 @@ cdef class FunctionOptions(_Weakrefable): unique_ptr[CFunctionOptions] c_options c_options = move(GetResultValue(move(maybe_options))) type_name = frombytes(c_options.get().options_type().type_name()) - mapping = { - "array_sort": ArraySortOptions, - "cast": CastOptions, - "dictionary_encode": DictionaryEncodeOptions, - "element_wise_aggregate": ElementWiseAggregateOptions, - "extract_regex": ExtractRegexOptions, - "filter": FilterOptions, - "index": IndexOptions, - "join": JoinOptions, - "match_substring": MatchSubstringOptions, - "mode": ModeOptions, - "pad": PadOptions, - "partition_nth": PartitionNthOptions, - "project": ProjectOptions, - "quantile": QuantileOptions, - "replace_slice": ReplaceSliceOptions, - "replace_substring": ReplaceSubstringOptions, - "set_lookup": SetLookupOptions, - "scalar_aggregate": ScalarAggregateOptions, - "slice": SliceOptions, - "sort": SortOptions, - "split": SplitOptions, - "split_pattern": SplitPatternOptions, - "strptime": StrptimeOptions, - "t_digest": TDigestOptions, - "take": TakeOptions, - "trim": TrimOptions, - "variance": VarianceOptions, - } - if type_name not in mapping: + module = globals() + if type_name not in module: raise ValueError(f"Cannot deserialize '{type_name}'") - klass = mapping[type_name] + klass = module[type_name] options = klass.__new__(klass) ( options).init(move(c_options)) return options