diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index dba71456c29..db1cac290cf 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -63,14 +63,14 @@ SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked") SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked") SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked") -Result ElementWiseMax(const std::vector& args, +Result MaxElementWise(const std::vector& args, ElementWiseAggregateOptions options, ExecContext* ctx) { - return CallFunction("element_wise_max", args, &options, ctx); + return CallFunction("max_element_wise", args, &options, ctx); } -Result ElementWiseMin(const std::vector& args, +Result MinElementWise(const std::vector& args, ElementWiseAggregateOptions options, ExecContext* ctx) { - return CallFunction("element_wise_min", args, &options, ctx); + return CallFunction("min_element_wise", args, &options, ctx); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 6e9a9340f2c..082876b356b 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -48,6 +48,25 @@ struct ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { bool skip_nulls; }; +/// Options for var_args_join. +struct ARROW_EXPORT JoinOptions : public FunctionOptions { + /// How to handle null values. (A null separator always results in a null output.) + enum NullHandlingBehavior { + /// A null in any input results in a null in the output. + EMIT_NULL, + /// Nulls in inputs are skipped. + SKIP, + /// Nulls in inputs are replaced with the replacement string. + REPLACE, + }; + explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL, + std::string null_replacement = "") + : null_handling(null_handling), null_replacement(std::move(null_replacement)) {} + static JoinOptions Defaults() { return JoinOptions(); } + NullHandlingBehavior null_handling; + std::string null_replacement; +}; + struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false) : pattern(std::move(pattern)), ignore_case(ignore_case) {} @@ -287,7 +306,7 @@ Result Power(const Datum& left, const Datum& right, /// \param[in] ctx the function execution context, optional /// \return the element-wise maximum ARROW_EXPORT -Result ElementWiseMax( +Result MaxElementWise( const std::vector& args, ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(), ExecContext* ctx = NULLPTR); @@ -300,7 +319,7 @@ Result ElementWiseMax( /// \param[in] ctx the function execution context, optional /// \return the element-wise minimum ARROW_EXPORT -Result ElementWiseMin( +Result MinElementWise( const std::vector& args, ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(), ExecContext* ctx = NULLPTR); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 6763b6793f3..041c6a282f9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -467,14 +467,14 @@ const FunctionDoc less_equal_doc{ ("A null on either side emits a null comparison result."), {"x", "y"}}; -const FunctionDoc element_wise_min_doc{ +const FunctionDoc min_element_wise_doc{ "Find the element-wise minimum value", ("Nulls will be ignored (default) or propagated. " "NaN will be taken over null, but not over any valid float."), {"*args"}, "ElementWiseAggregateOptions"}; -const FunctionDoc element_wise_max_doc{ +const FunctionDoc max_element_wise_doc{ "Find the element-wise maximum value", ("Nulls will be ignored (default) or propagated. " "NaN will be taken over null, but not over any valid float."), @@ -501,13 +501,13 @@ void RegisterScalarComparison(FunctionRegistry* registry) { // ---------------------------------------------------------------------- // Variadic element-wise functions - auto element_wise_min = - MakeScalarMinMax("element_wise_min", &element_wise_min_doc); - DCHECK_OK(registry->AddFunction(std::move(element_wise_min))); + auto min_element_wise = + MakeScalarMinMax("min_element_wise", &min_element_wise_doc); + DCHECK_OK(registry->AddFunction(std::move(min_element_wise))); - auto element_wise_max = - MakeScalarMinMax("element_wise_max", &element_wise_max_doc); - DCHECK_OK(registry->AddFunction(std::move(element_wise_max))); + auto max_element_wise = + MakeScalarMinMax("max_element_wise", &max_element_wise_doc); + DCHECK_OK(registry->AddFunction(std::move(max_element_wise))); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 6318a891d3a..50327e82032 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -729,90 +729,90 @@ TYPED_TEST_SUITE(TestVarArgsCompareNumeric, NumericBasedTypes); TYPED_TEST_SUITE(TestVarArgsCompareFloating, RealArrowTypes); TYPED_TEST_SUITE(TestVarArgsCompareParametricTemporal, ParametricTemporalTypes); -TYPED_TEST(TestVarArgsCompareNumeric, ElementWiseMin) { - this->AssertNullScalar(ElementWiseMin, {}); - this->AssertNullScalar(ElementWiseMin, {this->scalar("null"), this->scalar("null")}); +TYPED_TEST(TestVarArgsCompareNumeric, MinElementWise) { + this->AssertNullScalar(MinElementWise, {}); + this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")}); - this->Assert(ElementWiseMin, this->scalar("0"), {this->scalar("0")}); - this->Assert(ElementWiseMin, this->scalar("0"), + this->Assert(MinElementWise, this->scalar("0"), {this->scalar("0")}); + this->Assert(MinElementWise, this->scalar("0"), {this->scalar("2"), this->scalar("0"), this->scalar("1")}); this->Assert( - ElementWiseMin, this->scalar("0"), + MinElementWise, this->scalar("0"), {this->scalar("2"), this->scalar("0"), this->scalar("1"), this->scalar("null")}); - this->Assert(ElementWiseMin, this->scalar("1"), + this->Assert(MinElementWise, this->scalar("1"), {this->scalar("null"), this->scalar("null"), this->scalar("1"), this->scalar("null")}); - this->Assert(ElementWiseMin, (this->array("[]")), {this->array("[]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 3, null]"), + this->Assert(MinElementWise, (this->array("[]")), {this->array("[]")}); + this->Assert(MinElementWise, this->array("[1, 2, 3, null]"), {this->array("[1, 2, 3, null]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, 2, 3, 4]"), this->scalar("2")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, null, 3, 4]"), this->scalar("2")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, 2, 3, 4]"), this->array("[2, 2, 2, 2]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, null, 3, 4]"), this->array("[2, 2, 2, 2]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, null, 6]"), + this->Assert(MinElementWise, this->array("[1, 2, null, 6]"), {this->array("[1, 2, null, null]"), this->array("[4, null, null, 6]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, null, 6]"), + this->Assert(MinElementWise, this->array("[1, 2, null, 6]"), {this->array("[4, null, null, 6]"), this->array("[1, 2, null, null]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 3, 4]"), + this->Assert(MinElementWise, this->array("[1, 2, 3, 4]"), {this->array("[1, 2, 3, 4]"), this->array("[null, null, null, null]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 3, 4]"), + this->Assert(MinElementWise, this->array("[1, 2, 3, 4]"), {this->array("[null, null, null, null]"), this->array("[1, 2, 3, 4]")}); - this->Assert(ElementWiseMin, this->array("[1, 1, 1, 1]"), + this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"), {this->scalar("1"), this->array("[1, 2, 3, 4]")}); - this->Assert(ElementWiseMin, this->array("[1, 1, 1, 1]"), + this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"), {this->scalar("1"), this->array("[null, null, null, null]")}); - this->Assert(ElementWiseMin, this->array("[1, 1, 1, 1]"), + this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"), {this->scalar("null"), this->array("[1, 1, 1, 1]")}); - this->Assert(ElementWiseMin, this->array("[null, null, null, null]"), + this->Assert(MinElementWise, this->array("[null, null, null, null]"), {this->scalar("null"), this->array("[null, null, null, null]")}); // Test null handling this->element_wise_aggregate_options_.skip_nulls = false; - this->AssertNullScalar(ElementWiseMin, {this->scalar("null"), this->scalar("null")}); - this->AssertNullScalar(ElementWiseMin, {this->scalar("0"), this->scalar("null")}); + this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")}); + this->AssertNullScalar(MinElementWise, {this->scalar("0"), this->scalar("null")}); - this->Assert(ElementWiseMin, this->array("[1, null, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, null, 2, 2]"), {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")}); - this->Assert(ElementWiseMin, this->array("[null, null, null, null]"), + this->Assert(MinElementWise, this->array("[null, null, null, null]"), {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")}); - this->Assert(ElementWiseMin, this->array("[1, null, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, null, 2, 2]"), {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")}); - this->Assert(ElementWiseMin, this->array("[null, null, null, null]"), + this->Assert(MinElementWise, this->array("[null, null, null, null]"), {this->scalar("1"), this->array("[null, null, null, null]")}); - this->Assert(ElementWiseMin, this->array("[null, null, null, null]"), + this->Assert(MinElementWise, this->array("[null, null, null, null]"), {this->scalar("null"), this->array("[1, 1, 1, 1]")}); } -TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMin) { +TYPED_TEST(TestVarArgsCompareFloating, MinElementWise) { auto Check = [this](const std::string& expected, const std::vector& inputs) { std::vector args; for (const auto& input : inputs) { args.emplace_back(this->scalar(input)); } - this->Assert(ElementWiseMin, this->scalar(expected), args); + this->Assert(MinElementWise, this->scalar(expected), args); args.clear(); for (const auto& input : inputs) { args.emplace_back(this->array("[" + input + "]")); } - this->Assert(ElementWiseMin, this->array("[" + expected + "]"), args); + this->Assert(MinElementWise, this->array("[" + expected + "]"), args); }; Check("-0.0", {"0.0", "-0.0"}); Check("-0.0", {"1.0", "-0.0", "0.0"}); @@ -828,111 +828,111 @@ TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMin) { Check("-Inf", {"0", "-Inf"}); } -TYPED_TEST(TestVarArgsCompareParametricTemporal, ElementWiseMin) { +TYPED_TEST(TestVarArgsCompareParametricTemporal, MinElementWise) { // Temporal kernel is implemented with numeric kernel underneath - this->AssertNullScalar(ElementWiseMin, {}); - this->AssertNullScalar(ElementWiseMin, {this->scalar("null"), this->scalar("null")}); + this->AssertNullScalar(MinElementWise, {}); + this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")}); - this->Assert(ElementWiseMin, this->scalar("0"), {this->scalar("0")}); - this->Assert(ElementWiseMin, this->scalar("0"), {this->scalar("2"), this->scalar("0")}); - this->Assert(ElementWiseMin, this->scalar("0"), + this->Assert(MinElementWise, this->scalar("0"), {this->scalar("0")}); + this->Assert(MinElementWise, this->scalar("0"), {this->scalar("2"), this->scalar("0")}); + this->Assert(MinElementWise, this->scalar("0"), {this->scalar("0"), this->scalar("null")}); - this->Assert(ElementWiseMin, (this->array("[]")), {this->array("[]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 3, null]"), + this->Assert(MinElementWise, (this->array("[]")), {this->array("[]")}); + this->Assert(MinElementWise, this->array("[1, 2, 3, null]"), {this->array("[1, 2, 3, null]")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 2, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"), {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")}); - this->Assert(ElementWiseMin, this->array("[1, 2, 3, 2]"), + this->Assert(MinElementWise, this->array("[1, 2, 3, 2]"), {this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")}); } -TYPED_TEST(TestVarArgsCompareNumeric, ElementWiseMax) { - this->AssertNullScalar(ElementWiseMax, {}); - this->AssertNullScalar(ElementWiseMax, {this->scalar("null"), this->scalar("null")}); +TYPED_TEST(TestVarArgsCompareNumeric, MaxElementWise) { + this->AssertNullScalar(MaxElementWise, {}); + this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")}); - this->Assert(ElementWiseMax, this->scalar("0"), {this->scalar("0")}); - this->Assert(ElementWiseMax, this->scalar("2"), + this->Assert(MaxElementWise, this->scalar("0"), {this->scalar("0")}); + this->Assert(MaxElementWise, this->scalar("2"), {this->scalar("2"), this->scalar("0"), this->scalar("1")}); this->Assert( - ElementWiseMax, this->scalar("2"), + MaxElementWise, this->scalar("2"), {this->scalar("2"), this->scalar("0"), this->scalar("1"), this->scalar("null")}); - this->Assert(ElementWiseMax, this->scalar("1"), + this->Assert(MaxElementWise, this->scalar("1"), {this->scalar("null"), this->scalar("null"), this->scalar("1"), this->scalar("null")}); - this->Assert(ElementWiseMax, (this->array("[]")), {this->array("[]")}); - this->Assert(ElementWiseMax, this->array("[1, 2, 3, null]"), + this->Assert(MaxElementWise, (this->array("[]")), {this->array("[]")}); + this->Assert(MaxElementWise, this->array("[1, 2, 3, null]"), {this->array("[1, 2, 3, null]")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, 2, 3, 4]"), this->scalar("2")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, null, 3, 4]"), this->scalar("2")}); - this->Assert(ElementWiseMax, this->array("[4, 4, 4, 4]"), + this->Assert(MaxElementWise, this->array("[4, 4, 4, 4]"), {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, 2, 3, 4]"), this->array("[2, 2, 2, 2]")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, null, 3, 4]"), this->array("[2, 2, 2, 2]")}); - this->Assert(ElementWiseMax, this->array("[4, 2, null, 6]"), + this->Assert(MaxElementWise, this->array("[4, 2, null, 6]"), {this->array("[1, 2, null, null]"), this->array("[4, null, null, 6]")}); - this->Assert(ElementWiseMax, this->array("[4, 2, null, 6]"), + this->Assert(MaxElementWise, this->array("[4, 2, null, 6]"), {this->array("[4, null, null, 6]"), this->array("[1, 2, null, null]")}); - this->Assert(ElementWiseMax, this->array("[1, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"), {this->array("[1, 2, 3, 4]"), this->array("[null, null, null, null]")}); - this->Assert(ElementWiseMax, this->array("[1, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"), {this->array("[null, null, null, null]"), this->array("[1, 2, 3, 4]")}); - this->Assert(ElementWiseMax, this->array("[1, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"), {this->scalar("1"), this->array("[1, 2, 3, 4]")}); - this->Assert(ElementWiseMax, this->array("[1, 1, 1, 1]"), + this->Assert(MaxElementWise, this->array("[1, 1, 1, 1]"), {this->scalar("1"), this->array("[null, null, null, null]")}); - this->Assert(ElementWiseMax, this->array("[1, 1, 1, 1]"), + this->Assert(MaxElementWise, this->array("[1, 1, 1, 1]"), {this->scalar("null"), this->array("[1, 1, 1, 1]")}); - this->Assert(ElementWiseMax, this->array("[null, null, null, null]"), + this->Assert(MaxElementWise, this->array("[null, null, null, null]"), {this->scalar("null"), this->array("[null, null, null, null]")}); // Test null handling this->element_wise_aggregate_options_.skip_nulls = false; - this->AssertNullScalar(ElementWiseMax, {this->scalar("null"), this->scalar("null")}); - this->AssertNullScalar(ElementWiseMax, {this->scalar("0"), this->scalar("null")}); + this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")}); + this->AssertNullScalar(MaxElementWise, {this->scalar("0"), this->scalar("null")}); - this->Assert(ElementWiseMax, this->array("[4, null, 4, 4]"), + this->Assert(MaxElementWise, this->array("[4, null, 4, 4]"), {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")}); - this->Assert(ElementWiseMax, this->array("[null, null, null, null]"), + this->Assert(MaxElementWise, this->array("[null, null, null, null]"), {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")}); - this->Assert(ElementWiseMax, this->array("[2, null, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, null, 3, 4]"), {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")}); - this->Assert(ElementWiseMax, this->array("[null, null, null, null]"), + this->Assert(MaxElementWise, this->array("[null, null, null, null]"), {this->scalar("1"), this->array("[null, null, null, null]")}); - this->Assert(ElementWiseMax, this->array("[null, null, null, null]"), + this->Assert(MaxElementWise, this->array("[null, null, null, null]"), {this->scalar("null"), this->array("[1, 1, 1, 1]")}); } -TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMax) { +TYPED_TEST(TestVarArgsCompareFloating, MaxElementWise) { auto Check = [this](const std::string& expected, const std::vector& inputs) { std::vector args; for (const auto& input : inputs) { args.emplace_back(this->scalar(input)); } - this->Assert(ElementWiseMax, this->scalar(expected), args); + this->Assert(MaxElementWise, this->scalar(expected), args); args.clear(); for (const auto& input : inputs) { args.emplace_back(this->array("[" + input + "]")); } - this->Assert(ElementWiseMax, this->array("[" + expected + "]"), args); + this->Assert(MaxElementWise, this->array("[" + expected + "]"), args); }; Check("0.0", {"0.0", "-0.0"}); Check("1.0", {"1.0", "-0.0", "0.0"}); @@ -948,34 +948,34 @@ TYPED_TEST(TestVarArgsCompareFloating, ElementWiseMax) { Check("0", {"0", "-Inf"}); } -TYPED_TEST(TestVarArgsCompareParametricTemporal, ElementWiseMax) { +TYPED_TEST(TestVarArgsCompareParametricTemporal, MaxElementWise) { // Temporal kernel is implemented with numeric kernel underneath - this->AssertNullScalar(ElementWiseMax, {}); - this->AssertNullScalar(ElementWiseMax, {this->scalar("null"), this->scalar("null")}); + this->AssertNullScalar(MaxElementWise, {}); + this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")}); - this->Assert(ElementWiseMax, this->scalar("0"), {this->scalar("0")}); - this->Assert(ElementWiseMax, this->scalar("2"), {this->scalar("2"), this->scalar("0")}); - this->Assert(ElementWiseMax, this->scalar("0"), + this->Assert(MaxElementWise, this->scalar("0"), {this->scalar("0")}); + this->Assert(MaxElementWise, this->scalar("2"), {this->scalar("2"), this->scalar("0")}); + this->Assert(MaxElementWise, this->scalar("0"), {this->scalar("0"), this->scalar("null")}); - this->Assert(ElementWiseMax, (this->array("[]")), {this->array("[]")}); - this->Assert(ElementWiseMax, this->array("[1, 2, 3, null]"), + this->Assert(MaxElementWise, (this->array("[]")), {this->array("[]")}); + this->Assert(MaxElementWise, this->array("[1, 2, 3, null]"), {this->array("[1, 2, 3, null]")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")}); - this->Assert(ElementWiseMax, this->array("[2, 2, 3, 4]"), + this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"), {this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")}); } -TEST(TestElementWiseMaxElementWiseMin, CommonTimestamp) { +TEST(TestMaxElementWiseMinElementWise, CommonTimestamp) { { auto t1 = std::make_shared(TimeUnit::SECOND); auto t2 = std::make_shared(TimeUnit::MILLI); auto expected = MakeScalar(t2, 1000).ValueOrDie(); ASSERT_OK_AND_ASSIGN(auto actual, - ElementWiseMin({Datum(MakeScalar(t1, 1).ValueOrDie()), + MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()), Datum(MakeScalar(t2, 12000).ValueOrDie())})); AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true); } @@ -984,7 +984,7 @@ TEST(TestElementWiseMaxElementWiseMin, CommonTimestamp) { auto t2 = std::make_shared(TimeUnit::SECOND); auto expected = MakeScalar(t2, 86401).ValueOrDie(); ASSERT_OK_AND_ASSIGN(auto actual, - ElementWiseMax({Datum(MakeScalar(t1, 1).ValueOrDie()), + MaxElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()), Datum(MakeScalar(t2, 86401).ValueOrDie())})); AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true); } @@ -994,7 +994,7 @@ TEST(TestElementWiseMaxElementWiseMin, CommonTimestamp) { auto t3 = std::make_shared(TimeUnit::SECOND); auto expected = MakeScalar(t3, 86400).ValueOrDie(); ASSERT_OK_AND_ASSIGN( - auto actual, ElementWiseMin({Datum(MakeScalar(t1, 1).ValueOrDie()), + auto actual, MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()), Datum(MakeScalar(t2, 2 * 86400000).ValueOrDie())})); AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index cd054fcea0e..3f63bf2c405 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -3344,12 +3344,227 @@ struct BinaryJoin { } }; +using BinaryJoinElementWiseState = OptionsWrapper; + +template +struct BinaryJoinElementWise { + using ArrayType = typename TypeTraits::ArrayType; + using BuilderType = typename TypeTraits::BuilderType; + using offset_type = typename Type::offset_type; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + JoinOptions options = BinaryJoinElementWiseState::Get(ctx); + // Last argument is the separator (for consistency with binary_join) + if (std::all_of(batch.values.begin(), batch.values.end(), + [](const Datum& d) { return d.is_scalar(); })) { + return ExecOnlyScalar(ctx, options, batch, out); + } + return ExecContainingArrays(ctx, options, batch, out); + } + + static Status ExecOnlyScalar(KernelContext* ctx, const JoinOptions& options, + const ExecBatch& batch, Datum* out) { + BaseBinaryScalar* output = checked_cast(out->scalar().get()); + const size_t num_args = batch.values.size(); + if (num_args == 1) { + // Only separator, no values + ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0)); + output->is_valid = batch.values[0].scalar()->is_valid; + return Status::OK(); + } + + int64_t final_size = CalculateRowSize(options, batch, 0); + if (final_size < 0) { + ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0)); + output->is_valid = false; + return Status::OK(); + } + ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size)); + const auto separator = UnboxScalar::Unbox(*batch.values.back().scalar()); + uint8_t* buf = output->value->mutable_data(); + bool first = true; + for (size_t i = 0; i < num_args - 1; i++) { + const Scalar& scalar = *batch[i].scalar(); + util::string_view s; + if (scalar.is_valid) { + s = UnboxScalar::Unbox(scalar); + } else { + switch (options.null_handling) { + case JoinOptions::EMIT_NULL: + // Handled by CalculateRowSize + DCHECK(false) << "unreachable"; + break; + case JoinOptions::SKIP: + continue; + case JoinOptions::REPLACE: + s = options.null_replacement; + break; + } + } + if (!first) { + buf = std::copy(separator.begin(), separator.end(), buf); + } + first = false; + buf = std::copy(s.begin(), s.end(), buf); + } + output->is_valid = true; + DCHECK_EQ(final_size, buf - output->value->mutable_data()); + return Status::OK(); + } + + static Status ExecContainingArrays(KernelContext* ctx, const JoinOptions& options, + const ExecBatch& batch, Datum* out) { + // Presize data to avoid reallocations + int64_t final_size = 0; + for (int64_t i = 0; i < batch.length; i++) { + auto size = CalculateRowSize(options, batch, i); + if (size > 0) final_size += size; + } + BuilderType builder(ctx->memory_pool()); + RETURN_NOT_OK(builder.Reserve(batch.length)); + RETURN_NOT_OK(builder.ReserveData(final_size)); + + std::vector valid_cols(batch.values.size()); + for (size_t row = 0; row < static_cast(batch.length); row++) { + size_t num_valid = 0; // Not counting separator + for (size_t col = 0; col < batch.values.size(); col++) { + if (batch[col].is_scalar()) { + const auto& scalar = *batch[col].scalar(); + if (scalar.is_valid) { + valid_cols[col] = UnboxScalar::Unbox(scalar); + if (col < batch.values.size() - 1) num_valid++; + } else { + valid_cols[col] = util::string_view(); + } + } else { + const ArrayData& array = *batch[col].array(); + if (!array.MayHaveNulls() || + BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) { + const offset_type* offsets = array.GetValues(1); + const uint8_t* data = array.GetValues(2, /*absolute_offset=*/0); + const int64_t length = offsets[row + 1] - offsets[row]; + valid_cols[col] = util::string_view( + reinterpret_cast(data + offsets[row]), length); + if (col < batch.values.size() - 1) num_valid++; + } else { + valid_cols[col] = util::string_view(); + } + } + } + + if (!valid_cols.back().data()) { + // Separator is null + builder.UnsafeAppendNull(); + continue; + } else if (batch.values.size() == 1) { + // Only given separator + builder.UnsafeAppendEmptyValue(); + continue; + } else if (num_valid < batch.values.size() - 1) { + // We had some nulls + if (options.null_handling == JoinOptions::EMIT_NULL) { + builder.UnsafeAppendNull(); + continue; + } + } + const auto separator = valid_cols.back(); + bool first = true; + for (size_t col = 0; col < batch.values.size() - 1; col++) { + util::string_view value = valid_cols[col]; + if (!value.data()) { + switch (options.null_handling) { + case JoinOptions::EMIT_NULL: + DCHECK(false) << "unreachable"; + break; + case JoinOptions::SKIP: + continue; + case JoinOptions::REPLACE: + value = options.null_replacement; + break; + } + } + if (first) { + builder.UnsafeAppend(value); + first = false; + continue; + } + builder.UnsafeExtendCurrent(separator); + builder.UnsafeExtendCurrent(value); + } + } + + std::shared_ptr string_array; + RETURN_NOT_OK(builder.Finish(&string_array)); + *out = *string_array->data(); + out->mutable_array()->type = batch[0].type(); + DCHECK_EQ(batch.length, out->array()->length); + DCHECK_EQ(final_size, + checked_cast(*string_array).total_values_length()); + return Status::OK(); + } + + // Compute the length of the output for the given position, or -1 if it would be null. + static int64_t CalculateRowSize(const JoinOptions& options, const ExecBatch& batch, + const int64_t index) { + const auto num_args = batch.values.size(); + int64_t final_size = 0; + int64_t num_non_null_args = 0; + for (size_t i = 0; i < num_args; i++) { + int64_t element_size = 0; + bool valid = true; + if (batch[i].is_scalar()) { + const Scalar& scalar = *batch[i].scalar(); + valid = scalar.is_valid; + element_size = UnboxScalar::Unbox(scalar).size(); + } else { + const ArrayData& array = *batch[i].array(); + valid = !array.MayHaveNulls() || + BitUtil::GetBit(array.buffers[0]->data(), array.offset + index); + const offset_type* offsets = array.GetValues(1); + element_size = offsets[index + 1] - offsets[index]; + } + if (i == num_args - 1) { + if (!valid) return -1; + if (num_non_null_args > 1) { + // Add separator size (only if there were values to join) + final_size += (num_non_null_args - 1) * element_size; + } + break; + } + if (!valid) { + switch (options.null_handling) { + case JoinOptions::EMIT_NULL: + return -1; + case JoinOptions::SKIP: + continue; + case JoinOptions::REPLACE: + element_size = options.null_replacement.size(); + break; + } + } + num_non_null_args++; + final_size += element_size; + } + return final_size; + } +}; + const FunctionDoc binary_join_doc( "Join a list of strings together with a `separator` to form a single string", ("Insert `separator` between `list` elements, and concatenate them.\n" "Any null input and any null `list` element emits a null output.\n"), {"list", "separator"}); +const FunctionDoc binary_join_element_wise_doc( + "Join string arguments into one, using the last argument as the separator", + ("Insert the last argument of `strings` between the rest of the elements, " + "and concatenate them.\n" + "Any null separator element emits a null output. Null elements either " + "emit a null (the default), are skipped, or replaced with a given string.\n"), + {"*strings"}, "JoinOptions"); + +const auto kDefaultJoinOptions = JoinOptions::Defaults(); + template void AddBinaryJoinForListType(ScalarFunction* func) { for (const std::shared_ptr& ty : BaseBinaryTypes()) { @@ -3360,11 +3575,25 @@ void AddBinaryJoinForListType(ScalarFunction* func) { } void AddBinaryJoin(FunctionRegistry* registry) { - auto func = - std::make_shared("binary_join", Arity::Binary(), &binary_join_doc); - AddBinaryJoinForListType(func.get()); - AddBinaryJoinForListType(func.get()); - DCHECK_OK(registry->AddFunction(std::move(func))); + { + auto func = std::make_shared("binary_join", Arity::Binary(), + &binary_join_doc); + AddBinaryJoinForListType(func.get()); + AddBinaryJoinForListType(func.get()); + DCHECK_OK(registry->AddFunction(std::move(func))); + } + { + auto func = std::make_shared( + "binary_join_element_wise", Arity::VarArgs(/*min_args=*/1), + &binary_join_element_wise_doc, &kDefaultJoinOptions); + for (const auto& ty : BaseBinaryTypes()) { + DCHECK_OK( + func->AddKernel({InputType(ty)}, ty, + GenerateTypeAgnosticVarBinaryBase(ty), + BinaryJoinElementWiseState::Init)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); + } } template