From 50eb6d7f7f3caa45fb5282bc64ce20e0454d2585 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 1 Jul 2021 08:47:20 -0400
Subject: [PATCH] ARROW-13235: [C++][Python] Simplify mapping of function options
---
cpp/src/arrow/compute/api_aggregate.h | 12 ++++-----
cpp/src/arrow/compute/api_scalar.cc | 7 +++++
cpp/src/arrow/compute/api_scalar.h | 36 +++++++++++++-------------
cpp/src/arrow/compute/api_vector.h | 12 ++++-----
cpp/src/arrow/compute/cast.h | 2 +-
cpp/src/arrow/compute/function_test.cc | 2 ++
python/pyarrow/_compute.pyx | 34 +++---------------------
7 files changed, 43 insertions(+), 62 deletions(-)
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index 9be0b406aa4..7b6e2ef96de 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -46,7 +46,7 @@ class ExecContext;
class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions {
public:
explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1);
- constexpr static char const kTypeName[] = "scalar_aggregate";
+ constexpr static char const kTypeName[] = "ScalarAggregateOptions";
static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; }
bool skip_nulls;
@@ -60,7 +60,7 @@ class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions {
class ARROW_EXPORT ModeOptions : public FunctionOptions {
public:
explicit ModeOptions(int64_t n = 1);
- constexpr static char const kTypeName[] = "mode";
+ constexpr static char const kTypeName[] = "ModeOptions";
static ModeOptions Defaults() { return ModeOptions{}; }
int64_t n = 1;
@@ -73,7 +73,7 @@ class ARROW_EXPORT ModeOptions : public FunctionOptions {
class ARROW_EXPORT VarianceOptions : public FunctionOptions {
public:
explicit VarianceOptions(int ddof = 0);
- constexpr static char const kTypeName[] = "variance";
+ constexpr static char const kTypeName[] = "VarianceOptions";
static VarianceOptions Defaults() { return VarianceOptions{}; }
int ddof = 0;
@@ -98,7 +98,7 @@ class ARROW_EXPORT QuantileOptions : public FunctionOptions {
explicit QuantileOptions(std::vector q,
enum Interpolation interpolation = LINEAR);
- constexpr static char const kTypeName[] = "quantile";
+ constexpr static char const kTypeName[] = "QuantileOptions";
static QuantileOptions Defaults() { return QuantileOptions{}; }
/// quantile must be between 0 and 1 inclusive
@@ -115,7 +115,7 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions {
uint32_t buffer_size = 500);
explicit TDigestOptions(std::vector q, uint32_t delta = 100,
uint32_t buffer_size = 500);
- constexpr static char const kTypeName[] = "t_digest";
+ constexpr static char const kTypeName[] = "TDigestOptions";
static TDigestOptions Defaults() { return TDigestOptions{}; }
/// quantile must be between 0 and 1 inclusive
@@ -132,7 +132,7 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions {
explicit IndexOptions(std::shared_ptr value);
// Default constructor for serialization
IndexOptions();
- constexpr static char const kTypeName[] = "index";
+ constexpr static char const kTypeName[] = "IndexOptions";
std::shared_ptr value;
};
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index 11b5b45b7a0..2021c8a30c6 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -110,6 +110,8 @@ using ::arrow::internal::checked_cast;
namespace internal {
namespace {
using ::arrow::internal::DataMember;
+static auto kArithmeticOptionsType = GetFunctionOptionsType(
+ DataMember("check_overflow", &ArithmeticOptions::check_overflow));
static auto kElementWiseAggregateOptionsType =
GetFunctionOptionsType(
DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls));
@@ -159,6 +161,10 @@ static auto kProjectOptionsType = GetFunctionOptionsType(
} // namespace
} // namespace internal
+ArithmeticOptions::ArithmeticOptions(bool check_overflow)
+ : FunctionOptions(internal::kArithmeticOptionsType), check_overflow(check_overflow) {}
+constexpr char ArithmeticOptions::kTypeName[];
+
ElementWiseAggregateOptions::ElementWiseAggregateOptions(bool skip_nulls)
: FunctionOptions(internal::kElementWiseAggregateOptionsType),
skip_nulls(skip_nulls) {}
@@ -274,6 +280,7 @@ constexpr char ProjectOptions::kTypeName[];
namespace internal {
void RegisterScalarOptions(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunctionOptionsType(kArithmeticOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType));
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index bacb287d6bc..89b4faca940 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -37,17 +37,17 @@ namespace compute {
///
/// @{
-struct ARROW_EXPORT ArithmeticOptions {
+class ARROW_EXPORT ArithmeticOptions : public FunctionOptions {
public:
- explicit ArithmeticOptions(bool check_overflow = false)
- : check_overflow(check_overflow) {}
+ explicit ArithmeticOptions(bool check_overflow = false);
+ constexpr static char const kTypeName[] = "ArithmeticOptions";
bool check_overflow;
};
class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions {
public:
explicit ElementWiseAggregateOptions(bool skip_nulls = true);
- constexpr static char const kTypeName[] = "element_wise_aggregate";
+ constexpr static char const kTypeName[] = "ElementWiseAggregateOptions";
static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; }
bool skip_nulls;
@@ -67,7 +67,7 @@ class ARROW_EXPORT JoinOptions : public FunctionOptions {
};
explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL,
std::string null_replacement = "");
- constexpr static char const kTypeName[] = "join";
+ constexpr static char const kTypeName[] = "JoinOptions";
static JoinOptions Defaults() { return JoinOptions(); }
NullHandlingBehavior null_handling;
std::string null_replacement;
@@ -77,7 +77,7 @@ class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions {
public:
explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false);
MatchSubstringOptions();
- constexpr static char const kTypeName[] = "match_substring";
+ constexpr static char const kTypeName[] = "MatchSubstringOptions";
/// The exact substring (or regex, depending on kernel) to look for inside input values.
std::string pattern;
@@ -88,7 +88,7 @@ class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions {
class ARROW_EXPORT SplitOptions : public FunctionOptions {
public:
explicit SplitOptions(int64_t max_splits = -1, bool reverse = false);
- constexpr static char const kTypeName[] = "split";
+ constexpr static char const kTypeName[] = "SplitOptions";
/// Maximum number of splits allowed, or unlimited when -1
int64_t max_splits;
@@ -101,7 +101,7 @@ class ARROW_EXPORT SplitPatternOptions : public FunctionOptions {
explicit SplitPatternOptions(std::string pattern, int64_t max_splits = -1,
bool reverse = false);
SplitPatternOptions();
- constexpr static char const kTypeName[] = "split_pattern";
+ constexpr static char const kTypeName[] = "SplitPatternOptions";
/// The exact substring to split on.
std::string pattern;
@@ -115,7 +115,7 @@ class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions {
public:
explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement);
ReplaceSliceOptions();
- constexpr static char const kTypeName[] = "replace_slice";
+ constexpr static char const kTypeName[] = "ReplaceSliceOptions";
/// Index to start slicing at
int64_t start;
@@ -130,7 +130,7 @@ class ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
explicit ReplaceSubstringOptions(std::string pattern, std::string replacement,
int64_t max_replacements = -1);
ReplaceSubstringOptions();
- constexpr static char const kTypeName[] = "replace_substring";
+ constexpr static char const kTypeName[] = "ReplaceSubstringOptions";
/// Pattern to match, literal, or regular expression depending on which kernel is used
std::string pattern;
@@ -144,7 +144,7 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
public:
explicit ExtractRegexOptions(std::string pattern);
ExtractRegexOptions();
- constexpr static char const kTypeName[] = "extract_regex";
+ constexpr static char const kTypeName[] = "ExtractRegexOptions";
/// Regular expression with named capture fields
std::string pattern;
@@ -155,7 +155,7 @@ class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
public:
explicit SetLookupOptions(Datum value_set, bool skip_nulls = false);
SetLookupOptions();
- constexpr static char const kTypeName[] = "set_lookup";
+ constexpr static char const kTypeName[] = "SetLookupOptions";
/// The set of values to look up input values into.
Datum value_set;
@@ -172,7 +172,7 @@ class ARROW_EXPORT StrptimeOptions : public FunctionOptions {
public:
explicit StrptimeOptions(std::string format, TimeUnit::type unit);
StrptimeOptions();
- constexpr static char const kTypeName[] = "strptime";
+ constexpr static char const kTypeName[] = "StrptimeOptions";
std::string format;
TimeUnit::type unit;
@@ -182,7 +182,7 @@ class ARROW_EXPORT PadOptions : public FunctionOptions {
public:
explicit PadOptions(int64_t width, std::string padding = " ");
PadOptions();
- constexpr static char const kTypeName[] = "pad";
+ constexpr static char const kTypeName[] = "PadOptions";
/// The desired string length.
int64_t width;
@@ -194,7 +194,7 @@ class ARROW_EXPORT TrimOptions : public FunctionOptions {
public:
explicit TrimOptions(std::string characters);
TrimOptions();
- constexpr static char const kTypeName[] = "trim";
+ constexpr static char const kTypeName[] = "TrimOptions";
/// The individual characters that can be trimmed from the string.
std::string characters;
@@ -205,7 +205,7 @@ class ARROW_EXPORT SliceOptions : public FunctionOptions {
explicit SliceOptions(int64_t start, int64_t stop = std::numeric_limits::max(),
int64_t step = 1);
SliceOptions();
- constexpr static char const kTypeName[] = "slice";
+ constexpr static char const kTypeName[] = "SliceOptions";
int64_t start, stop, step;
};
@@ -222,7 +222,7 @@ class ARROW_EXPORT CompareOptions : public FunctionOptions {
public:
explicit CompareOptions(CompareOperator op);
CompareOptions();
- constexpr static char const kTypeName[] = "compare";
+ constexpr static char const kTypeName[] = "CompareOptions";
enum CompareOperator op;
};
@@ -232,7 +232,7 @@ class ARROW_EXPORT ProjectOptions : public FunctionOptions {
std::vector> m);
explicit ProjectOptions(std::vector n);
ProjectOptions();
- constexpr static char const kTypeName[] = "project";
+ constexpr static char const kTypeName[] = "ProjectOptions";
/// Names for wrapped columns
std::vector field_names;
diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h
index 2282b0098f9..6021492320e 100644
--- a/cpp/src/arrow/compute/api_vector.h
+++ b/cpp/src/arrow/compute/api_vector.h
@@ -43,7 +43,7 @@ class ARROW_EXPORT FilterOptions : public FunctionOptions {
};
explicit FilterOptions(NullSelectionBehavior null_selection = DROP);
- constexpr static char const kTypeName[] = "filter";
+ constexpr static char const kTypeName[] = "FilterOptions";
static FilterOptions Defaults() { return FilterOptions(); }
NullSelectionBehavior null_selection_behavior = DROP;
@@ -52,7 +52,7 @@ class ARROW_EXPORT FilterOptions : public FunctionOptions {
class ARROW_EXPORT TakeOptions : public FunctionOptions {
public:
explicit TakeOptions(bool boundscheck = true);
- constexpr static char const kTypeName[] = "take";
+ constexpr static char const kTypeName[] = "TakeOptions";
static TakeOptions BoundsCheck() { return TakeOptions(true); }
static TakeOptions NoBoundsCheck() { return TakeOptions(false); }
static TakeOptions Defaults() { return BoundsCheck(); }
@@ -72,7 +72,7 @@ class ARROW_EXPORT DictionaryEncodeOptions : public FunctionOptions {
};
explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK);
- constexpr static char const kTypeName[] = "dictionary_encode";
+ constexpr static char const kTypeName[] = "DictionaryEncodeOptions";
static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); }
NullEncodingBehavior null_encoding_behavior = MASK;
@@ -104,7 +104,7 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable {
class ARROW_EXPORT ArraySortOptions : public FunctionOptions {
public:
explicit ArraySortOptions(SortOrder order = SortOrder::Ascending);
- constexpr static char const kTypeName[] = "array_sort";
+ constexpr static char const kTypeName[] = "ArraySortOptions";
static ArraySortOptions Defaults() { return ArraySortOptions{}; }
SortOrder order;
@@ -113,7 +113,7 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions {
class ARROW_EXPORT SortOptions : public FunctionOptions {
public:
explicit SortOptions(std::vector sort_keys = {});
- constexpr static char const kTypeName[] = "sort";
+ constexpr static char const kTypeName[] = "SortOptions";
static SortOptions Defaults() { return SortOptions{}; }
std::vector sort_keys;
@@ -124,7 +124,7 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
public:
explicit PartitionNthOptions(int64_t pivot);
PartitionNthOptions() : PartitionNthOptions(0) {}
- constexpr static char const kTypeName[] = "partition_nth";
+ constexpr static char const kTypeName[] = "PartitionNthOptions";
/// The index into the equivalent sorted array of the partition pivot element.
int64_t pivot;
diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h
index 8abd2a71bca..131f57f892f 100644
--- a/cpp/src/arrow/compute/cast.h
+++ b/cpp/src/arrow/compute/cast.h
@@ -45,7 +45,7 @@ class ARROW_EXPORT CastOptions : public FunctionOptions {
public:
explicit CastOptions(bool safe = true);
- constexpr static char const kTypeName[] = "cast";
+ constexpr static char const kTypeName[] = "CastOptions";
static CastOptions Safe(std::shared_ptr to_type = NULLPTR) {
CastOptions safe(true);
safe.to_type = std::move(to_type);
diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc
index 4c42ce39600..bbe514af09a 100644
--- a/cpp/src/arrow/compute/function_test.cc
+++ b/cpp/src/arrow/compute/function_test.cc
@@ -53,6 +53,8 @@ TEST(FunctionOptions, Equality) {
options.emplace_back(new IndexOptions(ScalarFromJSON(int64(), "16")));
options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "true")));
options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "null")));
+ options.emplace_back(new ArithmeticOptions());
+ options.emplace_back(new ArithmeticOptions(/*check_overflow=*/true));
options.emplace_back(new ElementWiseAggregateOptions());
options.emplace_back(new ElementWiseAggregateOptions(/*skip_nulls=*/false));
options.emplace_back(new JoinOptions());
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index c8393103dc5..63e6fffc782 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -549,38 +549,10 @@ cdef class FunctionOptions(_Weakrefable):
unique_ptr[CFunctionOptions] c_options
c_options = move(GetResultValue(move(maybe_options)))
type_name = frombytes(c_options.get().options_type().type_name())
- mapping = {
- "array_sort": ArraySortOptions,
- "cast": CastOptions,
- "dictionary_encode": DictionaryEncodeOptions,
- "element_wise_aggregate": ElementWiseAggregateOptions,
- "extract_regex": ExtractRegexOptions,
- "filter": FilterOptions,
- "index": IndexOptions,
- "join": JoinOptions,
- "match_substring": MatchSubstringOptions,
- "mode": ModeOptions,
- "pad": PadOptions,
- "partition_nth": PartitionNthOptions,
- "project": ProjectOptions,
- "quantile": QuantileOptions,
- "replace_slice": ReplaceSliceOptions,
- "replace_substring": ReplaceSubstringOptions,
- "set_lookup": SetLookupOptions,
- "scalar_aggregate": ScalarAggregateOptions,
- "slice": SliceOptions,
- "sort": SortOptions,
- "split": SplitOptions,
- "split_pattern": SplitPatternOptions,
- "strptime": StrptimeOptions,
- "t_digest": TDigestOptions,
- "take": TakeOptions,
- "trim": TrimOptions,
- "variance": VarianceOptions,
- }
- if type_name not in mapping:
+ module = globals()
+ if type_name not in module:
raise ValueError(f"Cannot deserialize '{type_name}'")
- klass = mapping[type_name]
+ klass = module[type_name]
options = klass.__new__(klass)
( options).init(move(c_options))
return options