From c188236a8fee4fee926b49f409100c5582850905 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 11 Jun 2021 13:56:35 -0400
Subject: [PATCH 1/9] ARROW-12709: [C++] Add var_args_join
---
cpp/src/arrow/compute/api_scalar.h | 19 ++
.../arrow/compute/kernels/scalar_string.cc | 232 ++++++++++++++++++
.../compute/kernels/scalar_string_test.cc | 116 +++++++++
docs/source/cpp/compute.rst | 16 +-
docs/source/python/api/compute.rst | 17 ++
python/pyarrow/_compute.pyx | 31 +++
python/pyarrow/compute.py | 1 +
python/pyarrow/includes/libarrow.pxd | 16 ++
python/pyarrow/tests/test_compute.py | 25 ++
9 files changed, 468 insertions(+), 5 deletions(-)
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index 6e9a9340f2c..d6fd7075df2 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -48,6 +48,25 @@ struct ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions {
bool skip_nulls;
};
+/// Options for var_args_join.
+struct ARROW_EXPORT JoinOptions : public FunctionOptions {
+ /// How to handle null values. (A null separator always results in a null output.)
+ enum NullHandlingBehavior {
+ /// A null in any input results in a null in the output.
+ EMIT_NULL,
+ /// Nulls in inputs are skipped.
+ SKIP,
+ /// Nulls in inputs are replaced with the replacement string.
+ REPLACE,
+ };
+ explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL,
+ std::string null_replacement = "")
+ : null_handling(null_handling), null_replacement(std::move(null_replacement)) {}
+ static JoinOptions Defaults() { return JoinOptions(); }
+ NullHandlingBehavior null_handling;
+ std::string null_replacement;
+};
+
struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions {
explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false)
: pattern(std::move(pattern)), ignore_case(ignore_case) {}
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index cd054fcea0e..dd741eda713 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -3367,6 +3367,237 @@ void AddBinaryJoin(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+using VarArgsJoinState = OptionsWrapper;
+
+template
+struct VarArgsJoin {
+ using ArrayType = typename TypeTraits::ArrayType;
+ using BuilderType = typename TypeTraits::BuilderType;
+ using offset_type = typename Type::offset_type;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ JoinOptions options = VarArgsJoinState::Get(ctx);
+ // Last argument is the separator (for consistency with binary_join)
+ if (std::all_of(batch.values.begin(), batch.values.end(),
+ [](const Datum& d) { return d.is_scalar(); })) {
+ return ExecOnlyScalar(ctx, options, batch, out);
+ }
+ return ExecContainingArrays(ctx, options, batch, out);
+ }
+
+ static Status ExecOnlyScalar(KernelContext* ctx, const JoinOptions& options,
+ const ExecBatch& batch, Datum* out) {
+ BaseBinaryScalar* output = checked_cast(out->scalar().get());
+ const size_t num_args = batch.values.size();
+ if (num_args == 1) {
+ // Only separator, no values
+ ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
+ output->is_valid = batch.values[0].scalar()->is_valid;
+ return Status::OK();
+ }
+
+ int64_t final_size = CalculateRowSize(options, batch, 0);
+ if (final_size < 0) {
+ ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
+ output->is_valid = false;
+ return Status::OK();
+ }
+ ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size));
+ const auto separator = UnboxScalar::Unbox(*batch.values.back().scalar());
+ uint8_t* buf = output->value->mutable_data();
+ bool first = true;
+ for (size_t i = 0; i < num_args - 1; i++) {
+ const Scalar& scalar = *batch[i].scalar();
+ util::string_view s;
+ if (scalar.is_valid) {
+ s = UnboxScalar::Unbox(scalar);
+ } else {
+ switch (options.null_handling) {
+ case JoinOptions::EMIT_NULL:
+ // Handled by CalculateRowSize
+ DCHECK(false) << "unreachable";
+ break;
+ case JoinOptions::SKIP:
+ continue;
+ case JoinOptions::REPLACE:
+ s = options.null_replacement;
+ break;
+ }
+ }
+ if (!first) {
+ buf = std::copy(separator.begin(), separator.end(), buf);
+ }
+ first = false;
+ buf = std::copy(s.begin(), s.end(), buf);
+ }
+ output->is_valid = true;
+ return Status::OK();
+ }
+
+ static Status ExecContainingArrays(KernelContext* ctx, const JoinOptions& options,
+ const ExecBatch& batch, Datum* out) {
+ // Presize data to avoid reallocations
+ int64_t final_size = 0;
+ for (int64_t i = 0; i < batch.length; i++) {
+ auto size = CalculateRowSize(options, batch, i);
+ if (size > 0) final_size += size;
+ }
+ BuilderType builder(ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(batch.length));
+ RETURN_NOT_OK(builder.ReserveData(final_size));
+
+ std::vector valid_cols(batch.values.size());
+ for (size_t row = 0; row < static_cast(batch.length); row++) {
+ size_t num_valid = 0; // Not counting separator
+ for (size_t col = 0; col < batch.values.size(); col++) {
+ bool valid = false;
+ if (batch[col].is_scalar()) {
+ valid = batch[col].scalar()->is_valid;
+ } else {
+ const ArrayData& array = *batch[col].array();
+ valid = !array.MayHaveNulls() ||
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + row);
+ }
+ if (valid) {
+ valid_cols[col] = &batch[col];
+ if (col < batch.values.size() - 1) num_valid++;
+ } else {
+ valid_cols[col] = nullptr;
+ }
+ }
+
+ if (!valid_cols.back()) {
+ // Separator is null
+ builder.UnsafeAppendNull();
+ continue;
+ } else if (batch.values.size() == 1) {
+ // Only given separator
+ builder.UnsafeAppendEmptyValue();
+ continue;
+ } else if (num_valid < batch.values.size() - 1) {
+ // We had some nulls
+ if (options.null_handling == JoinOptions::EMIT_NULL) {
+ builder.UnsafeAppendNull();
+ continue;
+ }
+ }
+ const auto separator = Lookup(*valid_cols.back(), row);
+ bool first = true;
+ for (size_t col = 0; col < batch.values.size() - 1; col++) {
+ const Datum* datum = valid_cols[col];
+ util::string_view value;
+ if (!datum) {
+ switch (options.null_handling) {
+ case JoinOptions::EMIT_NULL:
+ DCHECK(false) << "unreachable";
+ break;
+ case JoinOptions::SKIP:
+ continue;
+ case JoinOptions::REPLACE:
+ value = options.null_replacement;
+ break;
+ }
+ } else {
+ value = Lookup(*datum, row);
+ }
+ if (first) {
+ builder.UnsafeAppend(value);
+ first = false;
+ continue;
+ }
+ builder.UnsafeExtendCurrent(separator);
+ builder.UnsafeExtendCurrent(value);
+ }
+ }
+
+ std::shared_ptr string_array;
+ RETURN_NOT_OK(builder.Finish(&string_array));
+ *out = *string_array->data();
+ out->mutable_array()->type = batch[0].type();
+ DCHECK_EQ(batch.length, out->array()->length);
+ return Status::OK();
+ }
+
+ // Unbox a scalar or the given element of an array.
+ static util::string_view Lookup(const Datum& datum, size_t row) {
+ if (datum.is_scalar()) {
+ return UnboxScalar::Unbox(*datum.scalar());
+ }
+ const ArrayData& array = *datum.array();
+ const offset_type* offsets = array.GetValues(1);
+ const uint8_t* data = array.GetValues(2, /*absolute_offset=*/0);
+ const int64_t length = offsets[row + 1] - offsets[row];
+ return util::string_view(reinterpret_cast(data + offsets[row]), length);
+ }
+
+ // Compute the length of the output for the given position, or -1 if it would be null.
+ static int64_t CalculateRowSize(const JoinOptions& options, const ExecBatch& batch,
+ const int64_t index) {
+ const auto num_args = batch.values.size();
+ int64_t final_size = 0;
+ int64_t num_non_null_args = 0;
+ for (size_t i = 0; i < num_args; i++) {
+ int64_t element_size = 0;
+ bool valid = true;
+ if (batch[i].is_scalar()) {
+ const Scalar& scalar = *batch[i].scalar();
+ valid = scalar.is_valid;
+ element_size = UnboxScalar::Unbox(scalar).size();
+ } else {
+ const ArrayData& array = *batch[i].array();
+ valid = !array.MayHaveNulls() ||
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + index);
+ const offset_type* offsets = array.GetValues(1);
+ element_size = offsets[index + 1] - offsets[index];
+ }
+ if (i == num_args - 1) {
+ if (!valid) return -1;
+ if (num_non_null_args > 1) {
+ // Add separator size (only if there were values to join)
+ final_size += (num_non_null_args - 1) * element_size;
+ }
+ break;
+ }
+ if (!valid) {
+ switch (options.null_handling) {
+ case JoinOptions::EMIT_NULL:
+ return -1;
+ case JoinOptions::SKIP:
+ continue;
+ case JoinOptions::REPLACE:
+ element_size = options.null_replacement.size();
+ break;
+ }
+ }
+ num_non_null_args++;
+ final_size += element_size;
+ }
+ return final_size;
+ }
+};
+
+const FunctionDoc var_args_join_doc(
+ "Join string arguments into one, using the last argument as the separator",
+ ("Insert the last argument of `strings` between the rest of the elements, "
+ "and concatenate them.\n"
+ "Any null separator element emits a null output. Null elements either "
+ "emit a null (the default), are skipped, or replaced with a given string.\n"),
+ {"*strings"}, "JoinOptions");
+
+const auto kDefaultJoinOptions = JoinOptions::Defaults();
+
+void AddVarArgsJoin(FunctionRegistry* registry) {
+ auto func =
+ std::make_shared("var_args_join", Arity::VarArgs(/*min_args=*/1),
+ &var_args_join_doc, &kDefaultJoinOptions);
+ for (const auto& ty : BaseBinaryTypes()) {
+ DCHECK_OK(func->AddKernel({InputType(ty)}, ty,
+ GenerateTypeAgnosticVarBinaryBase(ty),
+ VarArgsJoinState::Init));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
template class ExecFunctor>
void MakeUnaryStringBatchKernel(
std::string name, FunctionRegistry* registry, const FunctionDoc* doc,
@@ -3675,6 +3906,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddSplit(registry);
AddStrptime(registry);
AddBinaryJoin(registry);
+ AddVarArgsJoin(registry);
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index 2053dbaa971..d44b6a89b7f 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -58,6 +58,26 @@ class BaseTestStringKernels : public ::testing::Test {
json_expected, options);
}
+ void CheckVarArgsScalar(std::string func_name, std::string json_input,
+ std::shared_ptr out_ty, std::string json_expected,
+ const FunctionOptions* options = nullptr) {
+ // CheckScalar (on arrays) checks scalar arguments individually,
+ // but this lets us test the all-scalar case explicitly
+ ScalarVector inputs;
+ std::shared_ptr args = ArrayFromJSON(type(), json_input);
+ for (int64_t i = 0; i < args->length(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, args->GetScalar(i));
+ inputs.push_back(std::move(scalar));
+ }
+ CheckScalar(func_name, inputs, ScalarFromJSON(out_ty, json_expected), options);
+ }
+
+ void CheckVarArgs(std::string func_name, const std::vector& inputs,
+ std::shared_ptr out_ty, std::string json_expected,
+ const FunctionOptions* options = nullptr) {
+ CheckScalar(func_name, inputs, ArrayFromJSON(out_ty, json_expected), options);
+ }
+
std::shared_ptr type() { return TypeTraits::type_singleton(); }
template
@@ -229,6 +249,102 @@ TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
}
#endif
+TYPED_TEST(TestBinaryKernels, VarArgsJoin) {
+ const auto ty = this->type();
+ JoinOptions options;
+ JoinOptions options_skip(JoinOptions::SKIP);
+ JoinOptions options_replace(JoinOptions::REPLACE, "X");
+ // Scalar args, Scalar separator
+ this->CheckVarArgsScalar("var_args_join", R"([null])", ty, R"(null)", &options);
+ this->CheckVarArgsScalar("var_args_join", R"(["-"])", ty, R"("")", &options);
+ this->CheckVarArgsScalar("var_args_join", R"(["a", "-"])", ty, R"("a")", &options);
+ this->CheckVarArgsScalar("var_args_join", R"(["a", "b", "-"])", ty, R"("a-b")",
+ &options);
+ this->CheckVarArgsScalar("var_args_join", R"(["a", "b", null])", ty, R"(null)",
+ &options);
+ this->CheckVarArgsScalar("var_args_join", R"(["a", null, "-"])", ty, R"(null)",
+ &options);
+ this->CheckVarArgsScalar("var_args_join", R"(["foo", "bar", "baz", "++"])", ty,
+ R"("foo++bar++baz")", &options);
+
+ // Scalar args, Array separator
+ const auto sep = ArrayFromJSON(ty, R"([null, "-", "--"])");
+ const auto scalar1 = ScalarFromJSON(ty, R"("foo")");
+ const auto scalar2 = ScalarFromJSON(ty, R"("bar")");
+ const auto scalar3 = ScalarFromJSON(ty, R"("")");
+ const auto scalar_null = ScalarFromJSON(ty, R"(null)");
+ this->CheckVarArgs("var_args_join", {sep}, ty, R"([null, "", ""])", &options);
+ this->CheckVarArgs("var_args_join", {scalar1, sep}, ty, R"([null, "foo", "foo"])",
+ &options);
+ this->CheckVarArgs("var_args_join", {scalar1, scalar2, sep}, ty,
+ R"([null, "foo-bar", "foo--bar"])", &options);
+ this->CheckVarArgs("var_args_join", {scalar1, scalar_null, sep}, ty,
+ R"([null, null, null])", &options);
+ this->CheckVarArgs("var_args_join", {scalar1, scalar2, scalar3, sep}, ty,
+ R"([null, "foo-bar-", "foo--bar--"])", &options);
+
+ // Array args, Scalar separator
+ const auto sep1 = ScalarFromJSON(ty, R"("-")");
+ const auto sep2 = ScalarFromJSON(ty, R"("--")");
+ const auto arr1 = ArrayFromJSON(ty, R"([null, "a", "bb", "ccc"])");
+ const auto arr2 = ArrayFromJSON(ty, R"(["d", null, "e", ""])");
+ const auto arr3 = ArrayFromJSON(ty, R"(["gg", null, "h", "iii"])");
+ this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, scalar_null}, ty,
+ R"([null, null, null, null])", &options);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, sep1}, ty,
+ R"([null, null, "bb-e-h", "ccc--iii"])", &options);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, sep2}, ty,
+ R"([null, null, "bb--e--h", "ccc----iii"])", &options);
+
+ // Array args, Array separator
+ const auto sep3 = ArrayFromJSON(ty, R"(["-", "--", null, "---"])");
+ this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, sep3}, ty,
+ R"([null, null, null, "ccc------iii"])", &options);
+
+ // Mixed
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep3}, ty,
+ R"([null, null, null, "ccc------bar"])", &options);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, sep3}, ty,
+ R"([null, null, null, null])", &options);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep1}, ty,
+ R"([null, null, "bb-e-bar", "ccc--bar"])", &options);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, scalar_null}, ty,
+ R"([null, null, null, null])", &options);
+
+ // Skip
+ this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", "-"])", ty, R"("a-b")",
+ &options_skip);
+ this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", null])", ty, R"(null)",
+ &options_skip);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep3}, ty,
+ R"(["d-bar", "a--bar", null, "ccc------bar"])", &options_skip);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, sep3}, ty,
+ R"(["d", "a", null, "ccc---"])", &options_skip);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep1}, ty,
+ R"(["d-bar", "a-bar", "bb-e-bar", "ccc--bar"])", &options_skip);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, scalar_null}, ty,
+ R"([null, null, null, null])", &options_skip);
+
+ // Replace
+ this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", "-"])", ty, R"("a-X-b")",
+ &options_replace);
+ this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", null])", ty, R"(null)",
+ &options_replace);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep3}, ty,
+ R"(["X-d-bar", "a--X--bar", null, "ccc------bar"])",
+ &options_replace);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, sep3}, ty,
+ R"(["X-d-X", "a--X--X", null, "ccc------X"])", &options_replace);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep1}, ty,
+ R"(["X-d-bar", "a-X-bar", "bb-e-bar", "ccc--bar"])",
+ &options_replace);
+ this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, scalar_null}, ty,
+ R"([null, null, null, null])", &options_replace);
+
+ // Error cases
+ ASSERT_RAISES(Invalid, CallFunction("var_args_join", {}, &options));
+}
+
template
class TestStringKernels : public BaseTestStringKernels {};
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 91ee6bdf599..0c709ca0f8c 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -682,17 +682,23 @@ String joining
This function does the inverse of string splitting.
-+-----------------+-----------+----------------------+----------------+-------------------+---------+
-| Function name | Arity | Input type 1 | Input type 2 | Output type | Notes |
-+=================+===========+======================+================+===================+=========+
-| binary_join | Binary | List of string-like | String-like | String-like | \(1) |
-+-----------------+-----------+----------------------+----------------+-------------------+---------+
++-----------------+-----------+----------------------+----------------+-------------------+-----------------------+---------+
+| Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes |
++=================+===========+======================+================+===================+=======================+=========+
+| binary_join | Binary | List of string-like | String-like | String-like | | \(1) |
++-----------------+-----------+----------------------+----------------+-------------------+-----------------------+---------+
+| var_args_join | Varargs | List of string-like | (NA) | String-like | :struct:`JoinOptions` | \(2) |
++-----------------+-----------+----------------------+----------------+-------------------+-----------------------+---------+
* \(1) The first input must be an array, while the second can be a scalar or array.
Each list of values in the first input is joined using each second input
as separator. If any input list is null or contains a null, the corresponding
output will be null.
+* \(2) All arguments are concatenated element-wise, with the last argument treated
+ as the separator (scalars are recycled in either case). Null separators emit
+ null. If any other argument is null, by default the corresponding output will be
+ null, but it can instead either be skipped or replaced with a given string.
Slicing
~~~~~~~
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index dd722e44f05..ae231799b56 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -159,6 +159,23 @@ String Splitting
ascii_split_whitespace
utf8_split_whitespace
+String Component Extraction
+---------------------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ extract_regex
+
+String Joining
+--------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ binary_join
+ var_args_join
+
String Transforms
-----------------
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 104cd1bac1f..559a8a02b1c 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -667,6 +667,37 @@ class ElementWiseAggregateOptions(_ElementWiseAggregateOptions):
self._set_options(skip_nulls)
+cdef class _JoinOptions(FunctionOptions):
+ cdef:
+ unique_ptr[CJoinOptions] join_options
+
+ cdef const CFunctionOptions* get_options(self) except NULL:
+ return self.join_options.get()
+
+ def _set_options(self, null_handling, null_replacement):
+ cdef:
+ CJoinNullHandlingBehavior c_null_handling = \
+ CJoinNullHandlingBehavior_EMIT_NULL
+ c_string c_null_replacement = tobytes(null_replacement)
+ if null_handling == 'emit_null':
+ c_null_handling = CJoinNullHandlingBehavior_EMIT_NULL
+ elif null_handling == 'skip':
+ c_null_handling = CJoinNullHandlingBehavior_SKIP
+ elif null_handling == 'replace':
+ c_null_handling = CJoinNullHandlingBehavior_REPLACE
+ else:
+ raise ValueError(
+ '"{}" is not a valid null_handling'
+ .format(null_handling))
+ self.join_options.reset(
+ new CJoinOptions(c_null_handling, c_null_replacement))
+
+
+class JoinOptions(_JoinOptions):
+ def __init__(self, null_handling='emit_null', null_replacement=''):
+ self._set_options(null_handling, null_replacement)
+
+
cdef class _MatchSubstringOptions(FunctionOptions):
cdef:
unique_ptr[CMatchSubstringOptions] match_substring_options
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index b8bd9e65f17..b258b551f02 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -36,6 +36,7 @@
ExtractRegexOptions,
FilterOptions,
IndexOptions,
+ JoinOptions,
MatchSubstringOptions,
ModeOptions,
PartitionNthOptions,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 35a2034eba4..29baf49180d 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1787,6 +1787,22 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
CElementWiseAggregateOptions(c_bool skip_nulls)
c_bool skip_nulls
+ enum CJoinNullHandlingBehavior \
+ "arrow::compute::JoinOptions::NullHandlingBehavior":
+ CJoinNullHandlingBehavior_EMIT_NULL \
+ "arrow::compute::JoinOptions::EMIT_NULL"
+ CJoinNullHandlingBehavior_SKIP \
+ "arrow::compute::JoinOptions::SKIP"
+ CJoinNullHandlingBehavior_REPLACE \
+ "arrow::compute::JoinOptions::REPLACE"
+
+ cdef cppclass CJoinOptions \
+ "arrow::compute::JoinOptions"(CFunctionOptions):
+ CJoinOptions(CJoinNullHandlingBehavior null_handling,
+ c_string null_replacement)
+ CJoinNullHandlingBehavior null_handling
+ c_string null_replacement
+
cdef cppclass CMatchSubstringOptions \
"arrow::compute::MatchSubstringOptions"(CFunctionOptions):
CMatchSubstringOptions(c_string pattern, c_bool ignore_case)
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 1ed582db831..1f240dbcc6f 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -766,6 +766,31 @@ def test_binary_join():
assert pc.binary_join(ar_list, separator_array).equals(expected)
+def test_var_args_join():
+ null = pa.scalar(None, type=pa.string())
+ arrs = [[None, 'a', 'b'], ['c', None, 'd'], [None, '-', '--']]
+ assert pc.var_args_join(*arrs).to_pylist() == \
+ [None, None, 'b--d']
+ assert pc.var_args_join('a', 'b', '-').as_py() == 'a-b'
+ assert pc.var_args_join('a', null, '-').as_py() is None
+ assert pc.var_args_join('a', 'b', null).as_py() is None
+
+ skip = pc.JoinOptions('skip')
+ assert pc.var_args_join(*arrs, options=skip).to_pylist() == \
+ [None, 'a', 'b--d']
+ assert pc.var_args_join('a', 'b', '-', options=skip).as_py() == 'a-b'
+ assert pc.var_args_join('a', null, '-', options=skip).as_py() == 'a'
+ assert pc.var_args_join('a', 'b', null, options=skip).as_py() is None
+
+ replace = pc.JoinOptions('replace', null_replacement='spam')
+ assert pc.var_args_join(*arrs, options=replace).to_pylist() == \
+ [None, 'a-spam', 'b--d']
+ assert pc.var_args_join('a', 'b', '-', options=replace).as_py() == 'a-b'
+ assert pc.var_args_join(
+ 'a', null, '-', options=replace).as_py() == 'a-spam'
+ assert pc.var_args_join('a', 'b', null, options=replace).as_py() is None
+
+
@pytest.mark.parametrize(('ty', 'values'), all_array_types)
def test_take(ty, values):
arr = pa.array(values, type=ty)
From 6a0e4972c2a9dc489a88e2f772cf709d9d303fe3 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 14 Jun 2021 09:51:52 -0400
Subject: [PATCH 2/9] ARROW-12709: [C++] Clarify docs for var_args_join
---
docs/source/cpp/compute.rst | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 0c709ca0f8c..91d21a7277c 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -682,13 +682,13 @@ String joining
This function does the inverse of string splitting.
-+-----------------+-----------+----------------------+----------------+-------------------+-----------------------+---------+
-| Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes |
-+=================+===========+======================+================+===================+=======================+=========+
-| binary_join | Binary | List of string-like | String-like | String-like | | \(1) |
-+-----------------+-----------+----------------------+----------------+-------------------+-----------------------+---------+
-| var_args_join | Varargs | List of string-like | (NA) | String-like | :struct:`JoinOptions` | \(2) |
-+-----------------+-----------+----------------------+----------------+-------------------+-----------------------+---------+
++-----------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
+| Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes |
++=================+===========+=======================+================+===================+=======================+=========+
+| binary_join | Binary | List of string-like | String-like | String-like | | \(1) |
++-----------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
+| var_args_join | Varargs | String-like (varargs) | String-like | String-like | :struct:`JoinOptions` | \(2) |
++-----------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
* \(1) The first input must be an array, while the second can be a scalar or array.
Each list of values in the first input is joined using each second input
@@ -696,7 +696,7 @@ This function does the inverse of string splitting.
output will be null.
* \(2) All arguments are concatenated element-wise, with the last argument treated
- as the separator (scalars are recycled in either case). Null separators emit
+ as the separator (scalars are recycled in either case). Null separators emit
null. If any other argument is null, by default the corresponding output will be
null, but it can instead either be skipped or replaced with a given string.
From 061e3e630465eb9e7ac21c716c6bb7ea9054ad41 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 14 Jun 2021 15:00:46 -0400
Subject: [PATCH 3/9] ARROW-12709: [C++] Rename to binary_join_element_wise
---
.../arrow/compute/kernels/scalar_string.cc | 75 ++++++++--------
.../compute/kernels/scalar_string_test.cc | 89 ++++++++++---------
docs/source/cpp/compute.rst | 18 ++--
docs/source/python/api/compute.rst | 2 +-
python/pyarrow/tests/test_compute.py | 31 ++++---
5 files changed, 112 insertions(+), 103 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index dd741eda713..f1825810c56 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -3344,39 +3344,16 @@ struct BinaryJoin {
}
};
-const FunctionDoc binary_join_doc(
- "Join a list of strings together with a `separator` to form a single string",
- ("Insert `separator` between `list` elements, and concatenate them.\n"
- "Any null input and any null `list` element emits a null output.\n"),
- {"list", "separator"});
-
-template
-void AddBinaryJoinForListType(ScalarFunction* func) {
- for (const std::shared_ptr& ty : BaseBinaryTypes()) {
- auto exec = GenerateTypeAgnosticVarBinaryBase(*ty);
- auto list_ty = std::make_shared(ty);
- DCHECK_OK(func->AddKernel({InputType(list_ty), InputType(ty)}, ty, exec));
- }
-}
-
-void AddBinaryJoin(FunctionRegistry* registry) {
- auto func =
- std::make_shared("binary_join", Arity::Binary(), &binary_join_doc);
- AddBinaryJoinForListType(func.get());
- AddBinaryJoinForListType(func.get());
- DCHECK_OK(registry->AddFunction(std::move(func)));
-}
-
-using VarArgsJoinState = OptionsWrapper;
+using BinaryJoinElementWiseState = OptionsWrapper;
template
-struct VarArgsJoin {
+struct BinaryJoinElementWise {
using ArrayType = typename TypeTraits::ArrayType;
using BuilderType = typename TypeTraits::BuilderType;
using offset_type = typename Type::offset_type;
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- JoinOptions options = VarArgsJoinState::Get(ctx);
+ JoinOptions options = BinaryJoinElementWiseState::Get(ctx);
// Last argument is the separator (for consistency with binary_join)
if (std::all_of(batch.values.begin(), batch.values.end(),
[](const Datum& d) { return d.is_scalar(); })) {
@@ -3576,7 +3553,13 @@ struct VarArgsJoin {
}
};
-const FunctionDoc var_args_join_doc(
+const FunctionDoc binary_join_doc(
+ "Join a list of strings together with a `separator` to form a single string",
+ ("Insert `separator` between `list` elements, and concatenate them.\n"
+ "Any null input and any null `list` element emits a null output.\n"),
+ {"list", "separator"});
+
+const FunctionDoc binary_join_element_wise_doc(
"Join string arguments into one, using the last argument as the separator",
("Insert the last argument of `strings` between the rest of the elements, "
"and concatenate them.\n"
@@ -3586,16 +3569,35 @@ const FunctionDoc var_args_join_doc(
const auto kDefaultJoinOptions = JoinOptions::Defaults();
-void AddVarArgsJoin(FunctionRegistry* registry) {
- auto func =
- std::make_shared("var_args_join", Arity::VarArgs(/*min_args=*/1),
- &var_args_join_doc, &kDefaultJoinOptions);
- for (const auto& ty : BaseBinaryTypes()) {
- DCHECK_OK(func->AddKernel({InputType(ty)}, ty,
- GenerateTypeAgnosticVarBinaryBase(ty),
- VarArgsJoinState::Init));
+template
+void AddBinaryJoinForListType(ScalarFunction* func) {
+ for (const std::shared_ptr& ty : BaseBinaryTypes()) {
+ auto exec = GenerateTypeAgnosticVarBinaryBase(*ty);
+ auto list_ty = std::make_shared(ty);
+ DCHECK_OK(func->AddKernel({InputType(list_ty), InputType(ty)}, ty, exec));
+ }
+}
+
+void AddBinaryJoin(FunctionRegistry* registry) {
+ {
+ auto func = std::make_shared("binary_join", Arity::Binary(),
+ &binary_join_doc);
+ AddBinaryJoinForListType(func.get());
+ AddBinaryJoinForListType(func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func = std::make_shared(
+ "binary_join_element_wise", Arity::VarArgs(/*min_args=*/1),
+ &binary_join_element_wise_doc, &kDefaultJoinOptions);
+ for (const auto& ty : BaseBinaryTypes()) {
+ DCHECK_OK(
+ func->AddKernel({InputType(ty)}, ty,
+ GenerateTypeAgnosticVarBinaryBase(ty),
+ BinaryJoinElementWiseState::Init));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
}
- DCHECK_OK(registry->AddFunction(std::move(func)));
}
template class ExecFunctor>
@@ -3906,7 +3908,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddSplit(registry);
AddStrptime(registry);
AddBinaryJoin(registry);
- AddVarArgsJoin(registry);
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index d44b6a89b7f..6192e0a5dd7 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -249,23 +249,25 @@ TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
}
#endif
-TYPED_TEST(TestBinaryKernels, VarArgsJoin) {
+TYPED_TEST(TestBinaryKernels, BinaryJoinElementWise) {
const auto ty = this->type();
JoinOptions options;
JoinOptions options_skip(JoinOptions::SKIP);
JoinOptions options_replace(JoinOptions::REPLACE, "X");
// Scalar args, Scalar separator
- this->CheckVarArgsScalar("var_args_join", R"([null])", ty, R"(null)", &options);
- this->CheckVarArgsScalar("var_args_join", R"(["-"])", ty, R"("")", &options);
- this->CheckVarArgsScalar("var_args_join", R"(["a", "-"])", ty, R"("a")", &options);
- this->CheckVarArgsScalar("var_args_join", R"(["a", "b", "-"])", ty, R"("a-b")",
+ this->CheckVarArgsScalar("binary_join_element_wise", R"([null])", ty, R"(null)",
&options);
- this->CheckVarArgsScalar("var_args_join", R"(["a", "b", null])", ty, R"(null)",
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["-"])", ty, R"("")", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", "-"])", ty, R"("a")",
&options);
- this->CheckVarArgsScalar("var_args_join", R"(["a", null, "-"])", ty, R"(null)",
- &options);
- this->CheckVarArgsScalar("var_args_join", R"(["foo", "bar", "baz", "++"])", ty,
- R"("foo++bar++baz")", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", "b", "-"])", ty,
+ R"("a-b")", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", "b", null])", ty,
+ R"(null)", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "-"])", ty,
+ R"(null)", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["foo", "bar", "baz", "++"])",
+ ty, R"("foo++bar++baz")", &options);
// Scalar args, Array separator
const auto sep = ArrayFromJSON(ty, R"([null, "-", "--"])");
@@ -273,14 +275,15 @@ TYPED_TEST(TestBinaryKernels, VarArgsJoin) {
const auto scalar2 = ScalarFromJSON(ty, R"("bar")");
const auto scalar3 = ScalarFromJSON(ty, R"("")");
const auto scalar_null = ScalarFromJSON(ty, R"(null)");
- this->CheckVarArgs("var_args_join", {sep}, ty, R"([null, "", ""])", &options);
- this->CheckVarArgs("var_args_join", {scalar1, sep}, ty, R"([null, "foo", "foo"])",
+ this->CheckVarArgs("binary_join_element_wise", {sep}, ty, R"([null, "", ""])",
&options);
- this->CheckVarArgs("var_args_join", {scalar1, scalar2, sep}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, sep}, ty,
+ R"([null, "foo", "foo"])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, scalar2, sep}, ty,
R"([null, "foo-bar", "foo--bar"])", &options);
- this->CheckVarArgs("var_args_join", {scalar1, scalar_null, sep}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, scalar_null, sep}, ty,
R"([null, null, null])", &options);
- this->CheckVarArgs("var_args_join", {scalar1, scalar2, scalar3, sep}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, scalar2, scalar3, sep}, ty,
R"([null, "foo-bar-", "foo--bar--"])", &options);
// Array args, Scalar separator
@@ -289,60 +292,60 @@ TYPED_TEST(TestBinaryKernels, VarArgsJoin) {
const auto arr1 = ArrayFromJSON(ty, R"([null, "a", "bb", "ccc"])");
const auto arr2 = ArrayFromJSON(ty, R"(["d", null, "e", ""])");
const auto arr3 = ArrayFromJSON(ty, R"(["gg", null, "h", "iii"])");
- this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, scalar_null}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, scalar_null}, ty,
R"([null, null, null, null])", &options);
- this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, sep1}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, sep1}, ty,
R"([null, null, "bb-e-h", "ccc--iii"])", &options);
- this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, sep2}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, sep2}, ty,
R"([null, null, "bb--e--h", "ccc----iii"])", &options);
// Array args, Array separator
const auto sep3 = ArrayFromJSON(ty, R"(["-", "--", null, "---"])");
- this->CheckVarArgs("var_args_join", {arr1, arr2, arr3, sep3}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, sep3}, ty,
R"([null, null, null, "ccc------iii"])", &options);
// Mixed
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep3}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep3}, ty,
R"([null, null, null, "ccc------bar"])", &options);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, sep3}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, sep3}, ty,
R"([null, null, null, null])", &options);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep1}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep1}, ty,
R"([null, null, "bb-e-bar", "ccc--bar"])", &options);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, scalar_null}, ty,
- R"([null, null, null, null])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, scalar_null},
+ ty, R"([null, null, null, null])", &options);
// Skip
- this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", "-"])", ty, R"("a-b")",
- &options_skip);
- this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", null])", ty, R"(null)",
- &options_skip);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep3}, ty,
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", "-"])", ty,
+ R"("a-b")", &options_skip);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", null])", ty,
+ R"(null)", &options_skip);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep3}, ty,
R"(["d-bar", "a--bar", null, "ccc------bar"])", &options_skip);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, sep3}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, sep3}, ty,
R"(["d", "a", null, "ccc---"])", &options_skip);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep1}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep1}, ty,
R"(["d-bar", "a-bar", "bb-e-bar", "ccc--bar"])", &options_skip);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, scalar_null}, ty,
- R"([null, null, null, null])", &options_skip);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, scalar_null},
+ ty, R"([null, null, null, null])", &options_skip);
// Replace
- this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", "-"])", ty, R"("a-X-b")",
- &options_replace);
- this->CheckVarArgsScalar("var_args_join", R"(["a", null, "b", null])", ty, R"(null)",
- &options_replace);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep3}, ty,
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", "-"])", ty,
+ R"("a-X-b")", &options_replace);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", null])", ty,
+ R"(null)", &options_replace);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep3}, ty,
R"(["X-d-bar", "a--X--bar", null, "ccc------bar"])",
&options_replace);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, sep3}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, sep3}, ty,
R"(["X-d-X", "a--X--X", null, "ccc------X"])", &options_replace);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar2, sep1}, ty,
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep1}, ty,
R"(["X-d-bar", "a-X-bar", "bb-e-bar", "ccc--bar"])",
&options_replace);
- this->CheckVarArgs("var_args_join", {arr1, arr2, scalar_null, scalar_null}, ty,
- R"([null, null, null, null])", &options_replace);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, scalar_null},
+ ty, R"([null, null, null, null])", &options_replace);
// Error cases
- ASSERT_RAISES(Invalid, CallFunction("var_args_join", {}, &options));
+ ASSERT_RAISES(Invalid, CallFunction("binary_join_element_wise", {}, &options));
}
template
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 91d21a7277c..7d958b6abe8 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -680,15 +680,15 @@ String component extraction
String joining
~~~~~~~~~~~~~~
-This function does the inverse of string splitting.
-
-+-----------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
-| Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes |
-+=================+===========+=======================+================+===================+=======================+=========+
-| binary_join | Binary | List of string-like | String-like | String-like | | \(1) |
-+-----------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
-| var_args_join | Varargs | String-like (varargs) | String-like | String-like | :struct:`JoinOptions` | \(2) |
-+-----------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
+These functions do the inverse of string splitting.
+
++--------------------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
+| Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes |
++==========================+===========+=======================+================+===================+=======================+=========+
+| binary_join | Binary | List of string-like | String-like | String-like | | \(1) |
++--------------------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
+| binary_join_element_wise | Varargs | String-like (varargs) | String-like | String-like | :struct:`JoinOptions` | \(2) |
++--------------------------+-----------+-----------------------+----------------+-------------------+-----------------------+---------+
* \(1) The first input must be an array, while the second can be a scalar or array.
Each list of values in the first input is joined using each second input
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index ae231799b56..d178fde1b21 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -174,7 +174,7 @@ String Joining
:toctree: ../generated/
binary_join
- var_args_join
+ binary_join_element_wise
String Transforms
-----------------
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 1f240dbcc6f..d3824791052 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -766,29 +766,34 @@ def test_binary_join():
assert pc.binary_join(ar_list, separator_array).equals(expected)
-def test_var_args_join():
+def test_binary_join_element_wise():
null = pa.scalar(None, type=pa.string())
arrs = [[None, 'a', 'b'], ['c', None, 'd'], [None, '-', '--']]
- assert pc.var_args_join(*arrs).to_pylist() == \
+ assert pc.binary_join_element_wise(*arrs).to_pylist() == \
[None, None, 'b--d']
- assert pc.var_args_join('a', 'b', '-').as_py() == 'a-b'
- assert pc.var_args_join('a', null, '-').as_py() is None
- assert pc.var_args_join('a', 'b', null).as_py() is None
+ assert pc.binary_join_element_wise('a', 'b', '-').as_py() == 'a-b'
+ assert pc.binary_join_element_wise('a', null, '-').as_py() is None
+ assert pc.binary_join_element_wise('a', 'b', null).as_py() is None
skip = pc.JoinOptions('skip')
- assert pc.var_args_join(*arrs, options=skip).to_pylist() == \
+ assert pc.binary_join_element_wise(*arrs, options=skip).to_pylist() == \
[None, 'a', 'b--d']
- assert pc.var_args_join('a', 'b', '-', options=skip).as_py() == 'a-b'
- assert pc.var_args_join('a', null, '-', options=skip).as_py() == 'a'
- assert pc.var_args_join('a', 'b', null, options=skip).as_py() is None
+ assert pc.binary_join_element_wise(
+ 'a', 'b', '-', options=skip).as_py() == 'a-b'
+ assert pc.binary_join_element_wise(
+ 'a', null, '-', options=skip).as_py() == 'a'
+ assert pc.binary_join_element_wise(
+ 'a', 'b', null, options=skip).as_py() is None
replace = pc.JoinOptions('replace', null_replacement='spam')
- assert pc.var_args_join(*arrs, options=replace).to_pylist() == \
+ assert pc.binary_join_element_wise(*arrs, options=replace).to_pylist() == \
[None, 'a-spam', 'b--d']
- assert pc.var_args_join('a', 'b', '-', options=replace).as_py() == 'a-b'
- assert pc.var_args_join(
+ assert pc.binary_join_element_wise(
+ 'a', 'b', '-', options=replace).as_py() == 'a-b'
+ assert pc.binary_join_element_wise(
'a', null, '-', options=replace).as_py() == 'a-spam'
- assert pc.var_args_join('a', 'b', null, options=replace).as_py() is None
+ assert pc.binary_join_element_wise(
+ 'a', 'b', null, options=replace).as_py() is None
@pytest.mark.parametrize(('ty', 'values'), all_array_types)
From bd878f53e56422874288bea549bfcce392fbc173 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 14 Jun 2021 15:08:57 -0400
Subject: [PATCH 4/9] ARROW-12709: [C++] Rename to min/max_element_wise
---
cpp/src/arrow/compute/api_scalar.cc | 8 +-
cpp/src/arrow/compute/api_scalar.h | 4 +-
.../arrow/compute/kernels/scalar_compare.cc | 16 +-
.../compute/kernels/scalar_compare_test.cc | 184 +++++++++---------
docs/source/cpp/compute.rst | 4 +-
docs/source/python/api/compute.rst | 4 +-
python/pyarrow/tests/test_compute.py | 22 +--
7 files changed, 121 insertions(+), 121 deletions(-)
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index dba71456c29..db1cac290cf 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -63,14 +63,14 @@ SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked")
SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked")
SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked")
-Result ElementWiseMax(const std::vector& args,
+Result MaxElementWise(const std::vector& args,
ElementWiseAggregateOptions options, ExecContext* ctx) {
- return CallFunction("element_wise_max", args, &options, ctx);
+ return CallFunction("max_element_wise", args, &options, ctx);
}
-Result ElementWiseMin(const std::vector& args,
+Result MinElementWise(const std::vector& args,
ElementWiseAggregateOptions options, ExecContext* ctx) {
- return CallFunction("element_wise_min", args, &options, ctx);
+ return CallFunction("min_element_wise", args, &options, ctx);
}
// ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index d6fd7075df2..082876b356b 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -306,7 +306,7 @@ Result Power(const Datum& left, const Datum& right,
/// \param[in] ctx the function execution context, optional
/// \return the element-wise maximum
ARROW_EXPORT
-Result ElementWiseMax(
+Result MaxElementWise(
const std::vector& args,
ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(),
ExecContext* ctx = NULLPTR);
@@ -319,7 +319,7 @@ Result ElementWiseMax(
/// \param[in] ctx the function execution context, optional
/// \return the element-wise minimum
ARROW_EXPORT
-Result ElementWiseMin(
+Result MinElementWise(
const std::vector& args,
ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(),
ExecContext* ctx = NULLPTR);
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index 6763b6793f3..041c6a282f9 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -467,14 +467,14 @@ const FunctionDoc less_equal_doc{
("A null on either side emits a null comparison result."),
{"x", "y"}};
-const FunctionDoc element_wise_min_doc{
+const FunctionDoc min_element_wise_doc{
"Find the element-wise minimum value",
("Nulls will be ignored (default) or propagated. "
"NaN will be taken over null, but not over any valid float."),
{"*args"},
"ElementWiseAggregateOptions"};
-const FunctionDoc element_wise_max_doc{
+const FunctionDoc max_element_wise_doc{
"Find the element-wise maximum value",
("Nulls will be ignored (default) or propagated. "
"NaN will be taken over null, but not over any valid float."),
@@ -501,13 +501,13 @@ void RegisterScalarComparison(FunctionRegistry* registry) {
// ----------------------------------------------------------------------
// Variadic element-wise functions
- auto element_wise_min =
- MakeScalarMinMax("element_wise_min", &element_wise_min_doc);
- DCHECK_OK(registry->AddFunction(std::move(element_wise_min)));
+ auto min_element_wise =
+ MakeScalarMinMax("min_element_wise", &min_element_wise_doc);
+ DCHECK_OK(registry->AddFunction(std::move(min_element_wise)));
- auto element_wise_max =
- MakeScalarMinMax("element_wise_max", &element_wise_max_doc);
- DCHECK_OK(registry->AddFunction(std::move(element_wise_max)));
+ auto max_element_wise =
+ MakeScalarMinMax("max_element_wise", &max_element_wise_doc);
+ DCHECK_OK(registry->AddFunction(std::move(max_element_wise)));
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
index 6318a891d3a..50327e82032 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -729,90 +729,90 @@ TYPED_TEST_SUITE(TestVarArgsCompareNumeric, NumericBasedTypes);
TYPED_TEST_SUITE(TestVarArgsCompareFloating, RealArrowTypes);
TYPED_TEST_SUITE(TestVarArgsCompareParametricTemporal, ParametricTemporalTypes);
-TYPED_TEST(TestVarArgsCompareNumeric, ElementWiseMin) {
- this->AssertNullScalar(ElementWiseMin, {});
- this->AssertNullScalar(ElementWiseMin, {this->scalar("null"), this->scalar("null")});
+TYPED_TEST(TestVarArgsCompareNumeric, MinElementWise) {
+ this->AssertNullScalar(MinElementWise, {});
+ this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")});
- this->Assert(ElementWiseMin, this->scalar("0"), {this->scalar("0")});
- this->Assert(ElementWiseMin, this->scalar("0"),
+ this->Assert(MinElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MinElementWise, this->scalar("0"),
{this->scalar("2"), this->scalar("0"), this->scalar("1")});
this->Assert(
- ElementWiseMin, this->scalar("0"),
+ MinElementWise, this->scalar("0"),
{this->scalar("2"), this->scalar("0"), this->scalar("1"), this->scalar("null")});
- this->Assert(ElementWiseMin, this->scalar("1"),
+ this->Assert(MinElementWise, this->scalar("1"),
{this->scalar("null"), this->scalar("null"), this->scalar("1"),
this->scalar("null")});
- this->Assert(ElementWiseMin, (this->array("[]")), {this->array("[]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 3, null]"),
+ this->Assert(MinElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 3, null]"),
{this->array("[1, 2, 3, null]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, 2, 3, 4]"), this->scalar("2")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, null, 3, 4]"), this->scalar("2")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, 2, 3, 4]"), this->array("[2, 2, 2, 2]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, null, 3, 4]"), this->array("[2, 2, 2, 2]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, null, 6]"),
+ this->Assert(MinElementWise, this->array("[1, 2, null, 6]"),
{this->array("[1, 2, null, null]"), this->array("[4, null, null, 6]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, null, 6]"),
+ this->Assert(MinElementWise, this->array("[1, 2, null, 6]"),
{this->array("[4, null, null, 6]"), this->array("[1, 2, null, null]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 3, 4]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 3, 4]"),
{this->array("[1, 2, 3, 4]"), this->array("[null, null, null, null]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 3, 4]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 3, 4]"),
{this->array("[null, null, null, null]"), this->array("[1, 2, 3, 4]")});
- this->Assert(ElementWiseMin, this->array("[1, 1, 1, 1]"),
+ this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"),
{this->scalar("1"), this->array("[1, 2, 3, 4]")});
- this->Assert(ElementWiseMin, this->array("[1, 1, 1, 1]"),
+ this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"),
{this->scalar("1"), this->array("[null, null, null, null]")});
- this->Assert(ElementWiseMin, this->array("[1, 1, 1, 1]"),
+ this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"),
{this->scalar("null"), this->array("[1, 1, 1, 1]")});
- this->Assert(ElementWiseMin, this->array("[null, null, null, null]"),
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
{this->scalar("null"), this->array("[null, null, null, null]")});
// Test null handling
this->element_wise_aggregate_options_.skip_nulls = false;
- this->AssertNullScalar(ElementWiseMin, {this->scalar("null"), this->scalar("null")});
- this->AssertNullScalar(ElementWiseMin, {this->scalar("0"), this->scalar("null")});
+ this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")});
+ this->AssertNullScalar(MinElementWise, {this->scalar("0"), this->scalar("null")});
- this->Assert(ElementWiseMin, this->array("[1, null, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, null, 2, 2]"),
{this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
- this->Assert(ElementWiseMin, this->array("[null, null, null, null]"),
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
{this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
- this->Assert(ElementWiseMin, this->array("[1, null, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, null, 2, 2]"),
{this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
- this->Assert(ElementWiseMin, this->array("[null, null, null, null]"),
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
{this->scalar("1"), this->array("[null, null, null, null]")});
- this->Assert(ElementWiseMin, this->array("[null, null, null, null]"),
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
{this->scalar("null"), this->array("[1, 1, 1, 1]")});
}
-TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMin) {
+TYPED_TEST(TestVarArgsCompareFloating, MinElementWise) {
auto Check = [this](const std::string& expected,
const std::vector& inputs) {
std::vector args;
for (const auto& input : inputs) {
args.emplace_back(this->scalar(input));
}
- this->Assert(ElementWiseMin, this->scalar(expected), args);
+ this->Assert(MinElementWise, this->scalar(expected), args);
args.clear();
for (const auto& input : inputs) {
args.emplace_back(this->array("[" + input + "]"));
}
- this->Assert(ElementWiseMin, this->array("[" + expected + "]"), args);
+ this->Assert(MinElementWise, this->array("[" + expected + "]"), args);
};
Check("-0.0", {"0.0", "-0.0"});
Check("-0.0", {"1.0", "-0.0", "0.0"});
@@ -828,111 +828,111 @@ TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMin) {
Check("-Inf", {"0", "-Inf"});
}
-TYPED_TEST(TestVarArgsCompareParametricTemporal, ElementWiseMin) {
+TYPED_TEST(TestVarArgsCompareParametricTemporal, MinElementWise) {
// Temporal kernel is implemented with numeric kernel underneath
- this->AssertNullScalar(ElementWiseMin, {});
- this->AssertNullScalar(ElementWiseMin, {this->scalar("null"), this->scalar("null")});
+ this->AssertNullScalar(MinElementWise, {});
+ this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")});
- this->Assert(ElementWiseMin, this->scalar("0"), {this->scalar("0")});
- this->Assert(ElementWiseMin, this->scalar("0"), {this->scalar("2"), this->scalar("0")});
- this->Assert(ElementWiseMin, this->scalar("0"),
+ this->Assert(MinElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MinElementWise, this->scalar("0"), {this->scalar("2"), this->scalar("0")});
+ this->Assert(MinElementWise, this->scalar("0"),
{this->scalar("0"), this->scalar("null")});
- this->Assert(ElementWiseMin, (this->array("[]")), {this->array("[]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 3, null]"),
+ this->Assert(MinElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 3, null]"),
{this->array("[1, 2, 3, null]")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
{this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
- this->Assert(ElementWiseMin, this->array("[1, 2, 3, 2]"),
+ this->Assert(MinElementWise, this->array("[1, 2, 3, 2]"),
{this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")});
}
-TYPED_TEST(TestVarArgsCompareNumeric, ElementWiseMax) {
- this->AssertNullScalar(ElementWiseMax, {});
- this->AssertNullScalar(ElementWiseMax, {this->scalar("null"), this->scalar("null")});
+TYPED_TEST(TestVarArgsCompareNumeric, MaxElementWise) {
+ this->AssertNullScalar(MaxElementWise, {});
+ this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")});
- this->Assert(ElementWiseMax, this->scalar("0"), {this->scalar("0")});
- this->Assert(ElementWiseMax, this->scalar("2"),
+ this->Assert(MaxElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MaxElementWise, this->scalar("2"),
{this->scalar("2"), this->scalar("0"), this->scalar("1")});
this->Assert(
- ElementWiseMax, this->scalar("2"),
+ MaxElementWise, this->scalar("2"),
{this->scalar("2"), this->scalar("0"), this->scalar("1"), this->scalar("null")});
- this->Assert(ElementWiseMax, this->scalar("1"),
+ this->Assert(MaxElementWise, this->scalar("1"),
{this->scalar("null"), this->scalar("null"), this->scalar("1"),
this->scalar("null")});
- this->Assert(ElementWiseMax, (this->array("[]")), {this->array("[]")});
- this->Assert(ElementWiseMax, this->array("[1, 2, 3, null]"),
+ this->Assert(MaxElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, null]"),
{this->array("[1, 2, 3, null]")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, 2, 3, 4]"), this->scalar("2")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, null, 3, 4]"), this->scalar("2")});
- this->Assert(ElementWiseMax, this->array("[4, 4, 4, 4]"),
+ this->Assert(MaxElementWise, this->array("[4, 4, 4, 4]"),
{this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, 2, 3, 4]"), this->array("[2, 2, 2, 2]")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, null, 3, 4]"), this->array("[2, 2, 2, 2]")});
- this->Assert(ElementWiseMax, this->array("[4, 2, null, 6]"),
+ this->Assert(MaxElementWise, this->array("[4, 2, null, 6]"),
{this->array("[1, 2, null, null]"), this->array("[4, null, null, 6]")});
- this->Assert(ElementWiseMax, this->array("[4, 2, null, 6]"),
+ this->Assert(MaxElementWise, this->array("[4, 2, null, 6]"),
{this->array("[4, null, null, 6]"), this->array("[1, 2, null, null]")});
- this->Assert(ElementWiseMax, this->array("[1, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"),
{this->array("[1, 2, 3, 4]"), this->array("[null, null, null, null]")});
- this->Assert(ElementWiseMax, this->array("[1, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"),
{this->array("[null, null, null, null]"), this->array("[1, 2, 3, 4]")});
- this->Assert(ElementWiseMax, this->array("[1, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"),
{this->scalar("1"), this->array("[1, 2, 3, 4]")});
- this->Assert(ElementWiseMax, this->array("[1, 1, 1, 1]"),
+ this->Assert(MaxElementWise, this->array("[1, 1, 1, 1]"),
{this->scalar("1"), this->array("[null, null, null, null]")});
- this->Assert(ElementWiseMax, this->array("[1, 1, 1, 1]"),
+ this->Assert(MaxElementWise, this->array("[1, 1, 1, 1]"),
{this->scalar("null"), this->array("[1, 1, 1, 1]")});
- this->Assert(ElementWiseMax, this->array("[null, null, null, null]"),
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
{this->scalar("null"), this->array("[null, null, null, null]")});
// Test null handling
this->element_wise_aggregate_options_.skip_nulls = false;
- this->AssertNullScalar(ElementWiseMax, {this->scalar("null"), this->scalar("null")});
- this->AssertNullScalar(ElementWiseMax, {this->scalar("0"), this->scalar("null")});
+ this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")});
+ this->AssertNullScalar(MaxElementWise, {this->scalar("0"), this->scalar("null")});
- this->Assert(ElementWiseMax, this->array("[4, null, 4, 4]"),
+ this->Assert(MaxElementWise, this->array("[4, null, 4, 4]"),
{this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
- this->Assert(ElementWiseMax, this->array("[null, null, null, null]"),
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
{this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
- this->Assert(ElementWiseMax, this->array("[2, null, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, null, 3, 4]"),
{this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
- this->Assert(ElementWiseMax, this->array("[null, null, null, null]"),
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
{this->scalar("1"), this->array("[null, null, null, null]")});
- this->Assert(ElementWiseMax, this->array("[null, null, null, null]"),
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
{this->scalar("null"), this->array("[1, 1, 1, 1]")});
}
-TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMax) {
+TYPED_TEST(TestVarArgsCompareFloating, MaxElementWise) {
auto Check = [this](const std::string& expected,
const std::vector& inputs) {
std::vector args;
for (const auto& input : inputs) {
args.emplace_back(this->scalar(input));
}
- this->Assert(ElementWiseMax, this->scalar(expected), args);
+ this->Assert(MaxElementWise, this->scalar(expected), args);
args.clear();
for (const auto& input : inputs) {
args.emplace_back(this->array("[" + input + "]"));
}
- this->Assert(ElementWiseMax, this->array("[" + expected + "]"), args);
+ this->Assert(MaxElementWise, this->array("[" + expected + "]"), args);
};
Check("0.0", {"0.0", "-0.0"});
Check("1.0", {"1.0", "-0.0", "0.0"});
@@ -948,34 +948,34 @@ TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMax) {
Check("0", {"0", "-Inf"});
}
-TYPED_TEST(TestVarArgsCompareParametricTemporal, ElementWiseMax) {
+TYPED_TEST(TestVarArgsCompareParametricTemporal, MaxElementWise) {
// Temporal kernel is implemented with numeric kernel underneath
- this->AssertNullScalar(ElementWiseMax, {});
- this->AssertNullScalar(ElementWiseMax, {this->scalar("null"), this->scalar("null")});
+ this->AssertNullScalar(MaxElementWise, {});
+ this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")});
- this->Assert(ElementWiseMax, this->scalar("0"), {this->scalar("0")});
- this->Assert(ElementWiseMax, this->scalar("2"), {this->scalar("2"), this->scalar("0")});
- this->Assert(ElementWiseMax, this->scalar("0"),
+ this->Assert(MaxElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MaxElementWise, this->scalar("2"), {this->scalar("2"), this->scalar("0")});
+ this->Assert(MaxElementWise, this->scalar("0"),
{this->scalar("0"), this->scalar("null")});
- this->Assert(ElementWiseMax, (this->array("[]")), {this->array("[]")});
- this->Assert(ElementWiseMax, this->array("[1, 2, 3, null]"),
+ this->Assert(MaxElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, null]"),
{this->array("[1, 2, 3, null]")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
- this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"),
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
{this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")});
}
-TEST(TestElementWiseMaxElementWiseMin, CommonTimestamp) {
+TEST(TestMaxElementWiseMinElementWise, CommonTimestamp) {
{
auto t1 = std::make_shared<TimestampType>(TimeUnit::SECOND);
auto t2 = std::make_shared<TimestampType>(TimeUnit::MILLI);
auto expected = MakeScalar(t2, 1000).ValueOrDie();
ASSERT_OK_AND_ASSIGN(auto actual,
- ElementWiseMin({Datum(MakeScalar(t1, 1).ValueOrDie()),
+ MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()),
Datum(MakeScalar(t2, 12000).ValueOrDie())}));
AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true);
}
@@ -984,7 +984,7 @@ TEST(TestElementWiseMaxElementWiseMin, CommonTimestamp) {
auto t2 = std::make_shared<TimestampType>(TimeUnit::SECOND);
auto expected = MakeScalar(t2, 86401).ValueOrDie();
ASSERT_OK_AND_ASSIGN(auto actual,
- ElementWiseMax({Datum(MakeScalar(t1, 1).ValueOrDie()),
+ MaxElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()),
Datum(MakeScalar(t2, 86401).ValueOrDie())}));
AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true);
}
@@ -994,7 +994,7 @@ TEST(TestElementWiseMaxElementWiseMin, CommonTimestamp) {
auto t3 = std::make_shared<TimestampType>(TimeUnit::SECOND);
auto expected = MakeScalar(t3, 86400).ValueOrDie();
ASSERT_OK_AND_ASSIGN(
- auto actual, ElementWiseMin({Datum(MakeScalar(t1, 1).ValueOrDie()),
+ auto actual, MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()),
Datum(MakeScalar(t2, 2 * 86400000).ValueOrDie())}));
AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true);
}
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 7d958b6abe8..dfdd64d19c6 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -318,8 +318,8 @@ expanded for the purposes of comparison.
+--------------------------+------------+---------------------------------------------+---------------------+---------------------------------------+-------+
| Function names | Arity | Input types | Output type | Options class | Notes |
+==========================+============+=============================================+=====================+=======================================+=======+
-| element_wise_max, | Varargs | Numeric and Temporal | Numeric or Temporal | :struct:`ElementWiseAggregateOptions` | \(1) |
-| element_wise_min | | | | | |
+| max_element_wise, | Varargs | Numeric and Temporal | Numeric or Temporal | :struct:`ElementWiseAggregateOptions` | \(1) |
+| min_element_wise | | | | | |
+--------------------------+------------+---------------------------------------------+---------------------+---------------------------------------+-------+
* \(1) By default, nulls are skipped (but the kernel can be configured to propagate nulls).
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index d178fde1b21..80fcb2078f1 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -80,8 +80,8 @@ These functions take any number of arguments of a numeric or temporal type.
.. autosummary::
:toctree: ../generated/
- element_wise_max
- element_wise_min
+ max_element_wise
+ min_element_wise
Logical Functions
-----------------
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index d3824791052..efe2e6be2f8 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -1467,35 +1467,35 @@ def test_fill_null_segfault():
assert result == pa.array([0], pa.int8())
-def test_elementwise_min_max():
+def test_min_max_element_wise():
arr1 = pa.array([1, 2, 3])
arr2 = pa.array([3, 1, 2])
arr3 = pa.array([2, 3, None])
- result = pc.element_wise_max(arr1, arr2)
+ result = pc.max_element_wise(arr1, arr2)
assert result == pa.array([3, 2, 3])
- result = pc.element_wise_min(arr1, arr2)
+ result = pc.min_element_wise(arr1, arr2)
assert result == pa.array([1, 1, 2])
- result = pc.element_wise_max(arr1, arr2, arr3)
+ result = pc.max_element_wise(arr1, arr2, arr3)
assert result == pa.array([3, 3, 3])
- result = pc.element_wise_min(arr1, arr2, arr3)
+ result = pc.min_element_wise(arr1, arr2, arr3)
assert result == pa.array([1, 1, 2])
# with specifying the option
- result = pc.element_wise_max(arr1, arr3, skip_nulls=True)
+ result = pc.max_element_wise(arr1, arr3, skip_nulls=True)
assert result == pa.array([2, 3, 3])
- result = pc.element_wise_min(arr1, arr3, skip_nulls=True)
+ result = pc.min_element_wise(arr1, arr3, skip_nulls=True)
assert result == pa.array([1, 2, 3])
- result = pc.element_wise_max(
+ result = pc.max_element_wise(
arr1, arr3, options=pc.ElementWiseAggregateOptions())
assert result == pa.array([2, 3, 3])
- result = pc.element_wise_min(
+ result = pc.min_element_wise(
arr1, arr3, options=pc.ElementWiseAggregateOptions())
assert result == pa.array([1, 2, 3])
# not skipping nulls
- result = pc.element_wise_max(arr1, arr3, skip_nulls=False)
+ result = pc.max_element_wise(arr1, arr3, skip_nulls=False)
assert result == pa.array([2, 3, None])
- result = pc.element_wise_min(arr1, arr3, skip_nulls=False)
+ result = pc.min_element_wise(arr1, arr3, skip_nulls=False)
assert result == pa.array([1, 2, None])
From c3f20e586e509de3e6f17c12748b88cd5b80ce1a Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 15 Jun 2021 09:12:57 -0400
Subject: [PATCH 5/9] ARROW-12709: [C++] Add benchmark and debug checks
---
.../arrow/compute/kernels/scalar_string.cc | 3 ++
.../kernels/scalar_string_benchmark.cc | 44 +++++++++++++++++++
2 files changed, 47 insertions(+)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index f1825810c56..e8d0f3058de 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -3408,6 +3408,7 @@ struct BinaryJoinElementWise {
buf = std::copy(s.begin(), s.end(), buf);
}
output->is_valid = true;
+ DCHECK_EQ(final_size, buf - output->value->mutable_data());
return Status::OK();
}
@@ -3492,6 +3493,8 @@ struct BinaryJoinElementWise {
*out = *string_array->data();
out->mutable_array()->type = batch[0].type();
DCHECK_EQ(batch.length, out->array()->length);
+ DCHECK_EQ(final_size,
+ checked_cast<const ArrayType&>(*string_array).total_values_length());
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
index 606e774451c..943ac30d308 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
@@ -169,6 +169,48 @@ static void BinaryJoinArrayArray(benchmark::State& state) {
});
}
+static void BinaryJoinElementWise(benchmark::State& state,
+ SeparatorFactory make_separator) {
+ // Unfortunately benchmark is not 1:1 with BinaryJoin since BinaryJoin can join a
+ // varying number of inputs per output
+ const int64_t n_strings = 1000;
+ const int64_t n_lists = 10;
+ const double null_probability = 0.02;
+
+ random::RandomArrayGenerator rng(kSeed);
+
+ DatumVector args;
+ ArrayVector strings;
+ int64_t total_values_length = 0;
+ for (int i = 0; i < n_lists; i++) {
+ auto arr =
+ rng.String(n_strings, /*min_length=*/5, /*max_length=*/20, null_probability);
+ strings.push_back(arr);
+ args.emplace_back(arr);
+ total_values_length += checked_cast<const StringArray&>(*arr).total_values_length();
+ }
+ auto separator = make_separator(n_strings, null_probability);
+ args.emplace_back(separator);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction("binary_join_element_wise", args));
+ }
+ state.SetBytesProcessed(state.iterations() * total_values_length);
+}
+
+static void BinaryJoinElementWiseArrayScalar(benchmark::State& state) {
+ BinaryJoinElementWise(state, [](int64_t n, double null_probability) -> Datum {
+ return ScalarFromJSON(utf8(), R"("--")");
+ });
+}
+
+static void BinaryJoinElementWiseArrayArray(benchmark::State& state) {
+ BinaryJoinElementWise(state, [](int64_t n, double null_probability) -> Datum {
+ random::RandomArrayGenerator rng(kSeed + 1);
+ return rng.String(n, /*min_length=*/0, /*max_length=*/4, null_probability);
+ });
+}
+
BENCHMARK(AsciiLower);
BENCHMARK(AsciiUpper);
BENCHMARK(IsAlphaNumericAscii);
@@ -192,6 +234,8 @@ BENCHMARK(TrimManyUtf8);
BENCHMARK(BinaryJoinArrayScalar);
BENCHMARK(BinaryJoinArrayArray);
+BENCHMARK(BinaryJoinElementWiseArrayScalar);
+BENCHMARK(BinaryJoinElementWiseArrayArray);
} // namespace compute
} // namespace arrow
From 1a1ff7c43f69a9df8e6c6e742a2927f8f4d839fb Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 15 Jun 2021 09:47:47 -0400
Subject: [PATCH 6/9] ARROW-12709: [C++] Parameterize benchmark
---
cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
index 943ac30d308..e2bd27369b4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
@@ -174,7 +174,7 @@ static void BinaryJoinElementWise(benchmark::State& state,
// Unfortunately benchmark is not 1:1 with BinaryJoin since BinaryJoin can join a
// varying number of inputs per output
const int64_t n_strings = 1000;
- const int64_t n_lists = 10;
+ const int64_t n_lists = state.range(0);
const double null_probability = 0.02;
random::RandomArrayGenerator rng(kSeed);
@@ -234,8 +234,8 @@ BENCHMARK(TrimManyUtf8);
BENCHMARK(BinaryJoinArrayScalar);
BENCHMARK(BinaryJoinArrayArray);
-BENCHMARK(BinaryJoinElementWiseArrayScalar);
-BENCHMARK(BinaryJoinElementWiseArrayArray);
+BENCHMARK(BinaryJoinElementWiseArrayScalar)->RangeMultiplier(8)->Range(2, 128);
+BENCHMARK(BinaryJoinElementWiseArrayArray)->RangeMultiplier(8)->Range(2, 128);
} // namespace compute
} // namespace arrow
From d85e36211c09111b2b3908328d88586e0204e303 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 15 Jun 2021 12:52:15 -0400
Subject: [PATCH 7/9] ARROW-12709: [C++] Improve few-columns case at expense of
many-columns
---
.../arrow/compute/kernels/scalar_string.cc | 53 ++++++++-----------
1 file changed, 23 insertions(+), 30 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index e8d0f3058de..3f63bf2c405 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -3424,27 +3424,35 @@ struct BinaryJoinElementWise {
RETURN_NOT_OK(builder.Reserve(batch.length));
RETURN_NOT_OK(builder.ReserveData(final_size));
- std::vector<const Datum*> valid_cols(batch.values.size());
+ std::vector<util::string_view> valid_cols(batch.values.size());
for (size_t row = 0; row < static_cast<size_t>(batch.length); row++) {
size_t num_valid = 0; // Not counting separator
for (size_t col = 0; col < batch.values.size(); col++) {
- bool valid = false;
if (batch[col].is_scalar()) {
- valid = batch[col].scalar()->is_valid;
+ const auto& scalar = *batch[col].scalar();
+ if (scalar.is_valid) {
+ valid_cols[col] = UnboxScalar<Type>::Unbox(scalar);
+ if (col < batch.values.size() - 1) num_valid++;
+ } else {
+ valid_cols[col] = util::string_view();
+ }
} else {
const ArrayData& array = *batch[col].array();
- valid = !array.MayHaveNulls() ||
- BitUtil::GetBit(array.buffers[0]->data(), array.offset + row);
- }
- if (valid) {
- valid_cols[col] = &batch[col];
- if (col < batch.values.size() - 1) num_valid++;
- } else {
- valid_cols[col] = nullptr;
+ if (!array.MayHaveNulls() ||
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
+ const offset_type* offsets = array.GetValues<offset_type>(1);
+ const uint8_t* data = array.GetValues<uint8_t>(2, /*absolute_offset=*/0);
+ const int64_t length = offsets[row + 1] - offsets[row];
+ valid_cols[col] = util::string_view(
+ reinterpret_cast<const char*>(data + offsets[row]), length);
+ if (col < batch.values.size() - 1) num_valid++;
+ } else {
+ valid_cols[col] = util::string_view();
+ }
}
}
- if (!valid_cols.back()) {
+ if (!valid_cols.back().data()) {
// Separator is null
builder.UnsafeAppendNull();
continue;
@@ -3459,12 +3467,11 @@ struct BinaryJoinElementWise {
continue;
}
}
- const auto separator = Lookup(*valid_cols.back(), row);
+ const auto separator = valid_cols.back();
bool first = true;
for (size_t col = 0; col < batch.values.size() - 1; col++) {
- const Datum* datum = valid_cols[col];
- util::string_view value;
- if (!datum) {
+ util::string_view value = valid_cols[col];
+ if (!value.data()) {
switch (options.null_handling) {
case JoinOptions::EMIT_NULL:
DCHECK(false) << "unreachable";
@@ -3475,8 +3482,6 @@ struct BinaryJoinElementWise {
value = options.null_replacement;
break;
}
- } else {
- value = Lookup(*datum, row);
}
if (first) {
builder.UnsafeAppend(value);
@@ -3498,18 +3503,6 @@ struct BinaryJoinElementWise {
return Status::OK();
}
- // Unbox a scalar or the given element of an array.
- static util::string_view Lookup(const Datum& datum, size_t row) {
- if (datum.is_scalar()) {
- return UnboxScalar<Type>::Unbox(*datum.scalar());
- }
- const ArrayData& array = *datum.array();
- const offset_type* offsets = array.GetValues<offset_type>(1);
- const uint8_t* data = array.GetValues<uint8_t>(2, /*absolute_offset=*/0);
- const int64_t length = offsets[row + 1] - offsets[row];
- return util::string_view(reinterpret_cast<const char*>(data + offsets[row]), length);
- }
-
// Compute the length of the output for the given position, or -1 if it would be null.
static int64_t CalculateRowSize(const JoinOptions& options, const ExecBatch& batch,
const int64_t index) {
From 3e960d049bee48f833a8f31d032b52904940b54a Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 15 Jun 2021 13:06:59 -0400
Subject: [PATCH 8/9] ARROW-12709: [C++] Run benchmark with more elements
---
cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
index e2bd27369b4..1978e850155 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
@@ -173,7 +173,7 @@ static void BinaryJoinElementWise(benchmark::State& state,
SeparatorFactory make_separator) {
// Unfortunately benchmark is not 1:1 with BinaryJoin since BinaryJoin can join a
// varying number of inputs per output
- const int64_t n_strings = 1000;
+ const int64_t n_strings = 65536;
const int64_t n_lists = state.range(0);
const double null_probability = 0.02;
From 491d41aded16259feb7a514d534bb2c9e59e87f5 Mon Sep 17 00:00:00 2001
From: Antoine Pitrou
Date: Wed, 16 Jun 2021 15:20:24 +0200
Subject: [PATCH 9/9] Use same number of rows as in binary_join benchmark
---
.../arrow/compute/kernels/scalar_string_benchmark.cc | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
index 1978e850155..ddc3a56f00f 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
@@ -173,8 +173,8 @@ static void BinaryJoinElementWise(benchmark::State& state,
SeparatorFactory make_separator) {
// Unfortunately benchmark is not 1:1 with BinaryJoin since BinaryJoin can join a
// varying number of inputs per output
- const int64_t n_strings = 65536;
- const int64_t n_lists = state.range(0);
+ const int64_t n_rows = 10000;
+ const int64_t n_cols = state.range(0);
const double null_probability = 0.02;
random::RandomArrayGenerator rng(kSeed);
@@ -182,14 +182,13 @@ static void BinaryJoinElementWise(benchmark::State& state,
DatumVector args;
ArrayVector strings;
int64_t total_values_length = 0;
- for (int i = 0; i < n_lists; i++) {
- auto arr =
- rng.String(n_strings, /*min_length=*/5, /*max_length=*/20, null_probability);
+ for (int i = 0; i < n_cols; i++) {
+ auto arr = rng.String(n_rows, /*min_length=*/5, /*max_length=*/20, null_probability);
strings.push_back(arr);
args.emplace_back(arr);
total_values_length += checked_cast<const StringArray&>(*arr).total_values_length();
}
- auto separator = make_separator(n_strings, null_probability);
+ auto separator = make_separator(n_rows, null_probability);
args.emplace_back(separator);
for (auto _ : state) {