Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked")
SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked")
SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked")

Result<Datum> ElementWiseMax(const std::vector<Datum>& args,
Result<Datum> MaxElementWise(const std::vector<Datum>& args,
ElementWiseAggregateOptions options, ExecContext* ctx) {
return CallFunction("element_wise_max", args, &options, ctx);
return CallFunction("max_element_wise", args, &options, ctx);
}

Result<Datum> ElementWiseMin(const std::vector<Datum>& args,
Result<Datum> MinElementWise(const std::vector<Datum>& args,
ElementWiseAggregateOptions options, ExecContext* ctx) {
return CallFunction("element_wise_min", args, &options, ctx);
return CallFunction("min_element_wise", args, &options, ctx);
}

// ----------------------------------------------------------------------
Expand Down
23 changes: 21 additions & 2 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ struct ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions {
bool skip_nulls;
};

/// Options for var_args_join.
struct ARROW_EXPORT JoinOptions : public FunctionOptions {
/// How to handle null values. (A null separator always results in a null output.)
enum NullHandlingBehavior {
/// A null in any input results in a null in the output.
EMIT_NULL,
/// Nulls in inputs are skipped.
SKIP,
/// Nulls in inputs are replaced with the replacement string.
REPLACE,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, but I think we should avoid ALL_CAPS because of potential conflicts with third-party macros. What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to stay consistent with the existing enums. (Also see: the whole ML discussion…)

If we reach a consensus there I'm happy to rename all the enums.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's some other things I noticed in ARROW-13025 like a toplevel enum (not enum class).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think FunctionOptions classes and enums are recent enough that we may want to do a cleanup pass on them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(as a separate JIRA, of course!)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

};
explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL,
std::string null_replacement = "")
: null_handling(null_handling), null_replacement(std::move(null_replacement)) {}
static JoinOptions Defaults() { return JoinOptions(); }
NullHandlingBehavior null_handling;
std::string null_replacement;
};

struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions {
explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false)
: pattern(std::move(pattern)), ignore_case(ignore_case) {}
Expand Down Expand Up @@ -287,7 +306,7 @@ Result<Datum> Power(const Datum& left, const Datum& right,
/// \param[in] ctx the function execution context, optional
/// \return the element-wise maximum
ARROW_EXPORT
Result<Datum> ElementWiseMax(
Result<Datum> MaxElementWise(
const std::vector<Datum>& args,
ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(),
ExecContext* ctx = NULLPTR);
Expand All @@ -300,7 +319,7 @@ Result<Datum> ElementWiseMax(
/// \param[in] ctx the function execution context, optional
/// \return the element-wise minimum
ARROW_EXPORT
Result<Datum> ElementWiseMin(
Result<Datum> MinElementWise(
const std::vector<Datum>& args,
ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(),
ExecContext* ctx = NULLPTR);
Expand Down
16 changes: 8 additions & 8 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,14 +467,14 @@ const FunctionDoc less_equal_doc{
("A null on either side emits a null comparison result."),
{"x", "y"}};

const FunctionDoc element_wise_min_doc{
const FunctionDoc min_element_wise_doc{
"Find the element-wise minimum value",
("Nulls will be ignored (default) or propagated. "
"NaN will be taken over null, but not over any valid float."),
{"*args"},
"ElementWiseAggregateOptions"};

const FunctionDoc element_wise_max_doc{
const FunctionDoc max_element_wise_doc{
"Find the element-wise maximum value",
("Nulls will be ignored (default) or propagated. "
"NaN will be taken over null, but not over any valid float."),
Expand All @@ -501,13 +501,13 @@ void RegisterScalarComparison(FunctionRegistry* registry) {
// ----------------------------------------------------------------------
// Variadic element-wise functions

auto element_wise_min =
MakeScalarMinMax<Minimum>("element_wise_min", &element_wise_min_doc);
DCHECK_OK(registry->AddFunction(std::move(element_wise_min)));
auto min_element_wise =
MakeScalarMinMax<Minimum>("min_element_wise", &min_element_wise_doc);
DCHECK_OK(registry->AddFunction(std::move(min_element_wise)));

auto element_wise_max =
MakeScalarMinMax<Maximum>("element_wise_max", &element_wise_max_doc);
DCHECK_OK(registry->AddFunction(std::move(element_wise_max)));
auto max_element_wise =
MakeScalarMinMax<Maximum>("max_element_wise", &max_element_wise_doc);
DCHECK_OK(registry->AddFunction(std::move(max_element_wise)));
}

} // namespace internal
Expand Down
184 changes: 92 additions & 92 deletions cpp/src/arrow/compute/kernels/scalar_compare_test.cc

Large diffs are not rendered by default.

239 changes: 234 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3344,12 +3344,227 @@ struct BinaryJoin {
}
};

using BinaryJoinElementWiseState = OptionsWrapper<JoinOptions>;

template <typename Type>
struct BinaryJoinElementWise {
using ArrayType = typename TypeTraits<Type>::ArrayType;
using BuilderType = typename TypeTraits<Type>::BuilderType;
using offset_type = typename Type::offset_type;

static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
JoinOptions options = BinaryJoinElementWiseState::Get(ctx);
// Last argument is the separator (for consistency with binary_join)
if (std::all_of(batch.values.begin(), batch.values.end(),
[](const Datum& d) { return d.is_scalar(); })) {
return ExecOnlyScalar(ctx, options, batch, out);
}
return ExecContainingArrays(ctx, options, batch, out);
}

static Status ExecOnlyScalar(KernelContext* ctx, const JoinOptions& options,
const ExecBatch& batch, Datum* out) {
BaseBinaryScalar* output = checked_cast<BaseBinaryScalar*>(out->scalar().get());
const size_t num_args = batch.values.size();
if (num_args == 1) {
// Only separator, no values
ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
output->is_valid = batch.values[0].scalar()->is_valid;
return Status::OK();
}

int64_t final_size = CalculateRowSize(options, batch, 0);
if (final_size < 0) {
ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
output->is_valid = false;
return Status::OK();
}
ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size));
const auto separator = UnboxScalar<Type>::Unbox(*batch.values.back().scalar());
uint8_t* buf = output->value->mutable_data();
bool first = true;
for (size_t i = 0; i < num_args - 1; i++) {
const Scalar& scalar = *batch[i].scalar();
util::string_view s;
if (scalar.is_valid) {
s = UnboxScalar<Type>::Unbox(scalar);
} else {
switch (options.null_handling) {
case JoinOptions::EMIT_NULL:
// Handled by CalculateRowSize
DCHECK(false) << "unreachable";
break;
case JoinOptions::SKIP:
continue;
case JoinOptions::REPLACE:
s = options.null_replacement;
break;
}
}
if (!first) {
buf = std::copy(separator.begin(), separator.end(), buf);
}
first = false;
buf = std::copy(s.begin(), s.end(), buf);
}
output->is_valid = true;
DCHECK_EQ(final_size, buf - output->value->mutable_data());
return Status::OK();
}

static Status ExecContainingArrays(KernelContext* ctx, const JoinOptions& options,
const ExecBatch& batch, Datum* out) {
// Presize data to avoid reallocations
int64_t final_size = 0;
for (int64_t i = 0; i < batch.length; i++) {
auto size = CalculateRowSize(options, batch, i);
if (size > 0) final_size += size;
}
BuilderType builder(ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(batch.length));
RETURN_NOT_OK(builder.ReserveData(final_size));

std::vector<util::string_view> valid_cols(batch.values.size());
for (size_t row = 0; row < static_cast<size_t>(batch.length); row++) {
size_t num_valid = 0; // Not counting separator
for (size_t col = 0; col < batch.values.size(); col++) {
if (batch[col].is_scalar()) {
const auto& scalar = *batch[col].scalar();
if (scalar.is_valid) {
valid_cols[col] = UnboxScalar<Type>::Unbox(scalar);
if (col < batch.values.size() - 1) num_valid++;
} else {
valid_cols[col] = util::string_view();
}
} else {
const ArrayData& array = *batch[col].array();
if (!array.MayHaveNulls() ||
BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
const offset_type* offsets = array.GetValues<offset_type>(1);
const uint8_t* data = array.GetValues<uint8_t>(2, /*absolute_offset=*/0);
const int64_t length = offsets[row + 1] - offsets[row];
valid_cols[col] = util::string_view(
reinterpret_cast<const char*>(data + offsets[row]), length);
if (col < batch.values.size() - 1) num_valid++;
} else {
valid_cols[col] = util::string_view();
}
}
}

if (!valid_cols.back().data()) {
// Separator is null
builder.UnsafeAppendNull();
continue;
} else if (batch.values.size() == 1) {
// Only given separator
builder.UnsafeAppendEmptyValue();
continue;
} else if (num_valid < batch.values.size() - 1) {
// We had some nulls
if (options.null_handling == JoinOptions::EMIT_NULL) {
builder.UnsafeAppendNull();
continue;
}
}
const auto separator = valid_cols.back();
bool first = true;
for (size_t col = 0; col < batch.values.size() - 1; col++) {
util::string_view value = valid_cols[col];
if (!value.data()) {
switch (options.null_handling) {
case JoinOptions::EMIT_NULL:
DCHECK(false) << "unreachable";
break;
case JoinOptions::SKIP:
continue;
case JoinOptions::REPLACE:
value = options.null_replacement;
break;
}
}
if (first) {
builder.UnsafeAppend(value);
first = false;
continue;
}
builder.UnsafeExtendCurrent(separator);
builder.UnsafeExtendCurrent(value);
}
}

std::shared_ptr<Array> string_array;
RETURN_NOT_OK(builder.Finish(&string_array));
*out = *string_array->data();
out->mutable_array()->type = batch[0].type();
DCHECK_EQ(batch.length, out->array()->length);
DCHECK_EQ(final_size,
checked_cast<const ArrayType&>(*string_array).total_values_length());
return Status::OK();
}

// Compute the length of the output for the given position, or -1 if it would be null.
static int64_t CalculateRowSize(const JoinOptions& options, const ExecBatch& batch,
const int64_t index) {
const auto num_args = batch.values.size();
int64_t final_size = 0;
int64_t num_non_null_args = 0;
for (size_t i = 0; i < num_args; i++) {
int64_t element_size = 0;
bool valid = true;
if (batch[i].is_scalar()) {
const Scalar& scalar = *batch[i].scalar();
valid = scalar.is_valid;
element_size = UnboxScalar<Type>::Unbox(scalar).size();
} else {
const ArrayData& array = *batch[i].array();
valid = !array.MayHaveNulls() ||
BitUtil::GetBit(array.buffers[0]->data(), array.offset + index);
const offset_type* offsets = array.GetValues<offset_type>(1);
element_size = offsets[index + 1] - offsets[index];
}
if (i == num_args - 1) {
if (!valid) return -1;
if (num_non_null_args > 1) {
// Add separator size (only if there were values to join)
final_size += (num_non_null_args - 1) * element_size;
}
break;
}
if (!valid) {
switch (options.null_handling) {
case JoinOptions::EMIT_NULL:
return -1;
case JoinOptions::SKIP:
continue;
case JoinOptions::REPLACE:
element_size = options.null_replacement.size();
break;
}
}
num_non_null_args++;
final_size += element_size;
}
return final_size;
}
};

const FunctionDoc binary_join_doc(
"Join a list of strings together with a `separator` to form a single string",
("Insert `separator` between `list` elements, and concatenate them.\n"
"Any null input and any null `list` element emits a null output.\n"),
{"list", "separator"});

const FunctionDoc binary_join_element_wise_doc(
"Join string arguments into one, using the last argument as the separator",
("Insert the last argument of `strings` between the rest of the elements, "
"and concatenate them.\n"
"Any null separator element emits a null output. Null elements either "
"emit a null (the default), are skipped, or replaced with a given string.\n"),
{"*strings"}, "JoinOptions");

const auto kDefaultJoinOptions = JoinOptions::Defaults();

template <typename ListType>
void AddBinaryJoinForListType(ScalarFunction* func) {
for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) {
Expand All @@ -3360,11 +3575,25 @@ void AddBinaryJoinForListType(ScalarFunction* func) {
}

void AddBinaryJoin(FunctionRegistry* registry) {
auto func =
std::make_shared<ScalarFunction>("binary_join", Arity::Binary(), &binary_join_doc);
AddBinaryJoinForListType<ListType>(func.get());
AddBinaryJoinForListType<LargeListType>(func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));
{
auto func = std::make_shared<ScalarFunction>("binary_join", Arity::Binary(),
&binary_join_doc);
AddBinaryJoinForListType<ListType>(func.get());
AddBinaryJoinForListType<LargeListType>(func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
auto func = std::make_shared<ScalarFunction>(
"binary_join_element_wise", Arity::VarArgs(/*min_args=*/1),
&binary_join_element_wise_doc, &kDefaultJoinOptions);
for (const auto& ty : BaseBinaryTypes()) {
DCHECK_OK(
func->AddKernel({InputType(ty)}, ty,
GenerateTypeAgnosticVarBinaryBase<BinaryJoinElementWise>(ty),
BinaryJoinElementWiseState::Init));
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
}

template <template <typename> class ExecFunctor>
Expand Down
Loading