9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -47,6 +47,7 @@ std::vector<std::shared_ptr<DataType>> g_floating_types;
std::vector<std::shared_ptr<DataType>> g_numeric_types;
std::vector<std::shared_ptr<DataType>> g_base_binary_types;
std::vector<std::shared_ptr<DataType>> g_temporal_types;
std::vector<std::shared_ptr<DataType>> g_interval_types;
std::vector<std::shared_ptr<DataType>> g_primitive_types;
std::vector<Type::type> g_decimal_type_ids;
static std::once_flag codegen_static_initialized;
@@ -91,6 +92,9 @@ static void InitStaticData() {
timestamp(TimeUnit::MICRO),
timestamp(TimeUnit::NANO)};

// Interval types
g_interval_types = {day_time_interval(), month_interval()};

// Base binary types (without FixedSizeBinary)
g_base_binary_types = {binary(), utf8(), large_binary(), large_utf8()};

@@ -157,6 +161,11 @@ const std::vector<std::shared_ptr<DataType>>& TemporalTypes() {
return g_temporal_types;
}

const std::vector<std::shared_ptr<DataType>>& IntervalTypes() {
std::call_once(codegen_static_initialized, InitStaticData);
return g_interval_types;
}

const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes() {
std::call_once(codegen_static_initialized, InitStaticData);
return g_primitive_types;
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -442,6 +442,9 @@ const std::vector<std::shared_ptr<DataType>>& NumericTypes();
// Temporal types including time and timestamps for each unit
const std::vector<std::shared_ptr<DataType>>& TemporalTypes();

// Interval types
const std::vector<std::shared_ptr<DataType>>& IntervalTypes();

// Integer, floating point, base binary, and temporal
const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes();

278 changes: 273 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1182,6 +1182,28 @@ void CopyOneArrayValue(const DataType& type, const uint8_t* in_valid,
out_offset);
}

template <typename Type>
void CopyOneScalarValue(const Scalar& scalar, uint8_t* out_valid, uint8_t* out_values,
const int64_t out_offset) {
if (out_valid) {
BitUtil::SetBitTo(out_valid, out_offset, scalar.is_valid);
}
CopyFixedWidth<Type>::CopyScalar(scalar, /*length=*/1, out_values, out_offset);
}

template <typename Type>
void CopyOneValue(const Datum& in_values, const int64_t in_offset, uint8_t* out_valid,
uint8_t* out_values, const int64_t out_offset) {
if (in_values.is_array()) {
const ArrayData& array = *in_values.array();
CopyOneArrayValue<Type>(*array.type, array.GetValues<uint8_t>(0, 0),
array.GetValues<uint8_t>(1, 0), array.offset + in_offset,
out_valid, out_values, out_offset);
} else {
CopyOneScalarValue<Type>(*in_values.scalar(), out_valid, out_values, out_offset);
}
}

struct CaseWhenFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

@@ -1606,6 +1628,206 @@ struct CoalesceFunctor<Type, enable_if_base_binary<Type>> {
}
};

template <typename Type>
Status ExecScalarChoose(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& index_scalar = *batch[0].scalar();
if (!index_scalar.is_valid) {
if (out->is_array()) {
auto source = MakeNullScalar(out->type());
ArrayData* output = out->mutable_array();
CopyValues<Type>(source, /*row=*/0, batch.length,
output->GetMutableValues<uint8_t>(0, /*absolute_offset=*/0),
output->GetMutableValues<uint8_t>(1, /*absolute_offset=*/0),
output->offset);
}
return Status::OK();
}
auto index = UnboxScalar<Int64Type>::Unbox(index_scalar);
if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
return Status::IndexError("choose: index ", index, " out of range");
}
auto source = batch.values[index + 1];
if (out->is_scalar()) {
*out = source;
} else {
ArrayData* output = out->mutable_array();
CopyValues<Type>(source, /*row=*/0, batch.length,
output->GetMutableValues<uint8_t>(0, /*absolute_offset=*/0),
output->GetMutableValues<uint8_t>(1, /*absolute_offset=*/0),
output->offset);
}
return Status::OK();
}

template <typename Type>
Status ExecArrayChoose(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ArrayData* output = out->mutable_array();
const int64_t out_offset = output->offset;
// Need a null bitmap if any input has nulls
uint8_t* out_valid = nullptr;
if (std::any_of(batch.values.begin(), batch.values.end(),
[](const Datum& d) { return d.null_count() > 0; })) {
out_valid = output->buffers[0]->mutable_data();
} else {
BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), out_offset, batch.length,
true);
}
uint8_t* out_values = output->buffers[1]->mutable_data();
int64_t row = 0;
return VisitArrayValuesInline<Int64Type>(
*batch[0].array(),
[&](int64_t index) {
if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
return Status::IndexError("choose: index ", index, " out of range");
}
const auto& source = batch.values[index + 1];
CopyOneValue<Type>(source, row, out_valid, out_values, out_offset + row);
row++;
return Status::OK();
},
[&]() {
// Index is null, but we should still initialize the output with some value
const auto& source = batch.values[1];
CopyOneValue<Type>(source, row, out_valid, out_values, out_offset + row);
BitUtil::ClearBit(out_valid, out_offset + row);
row++;
return Status::OK();
});
}

template <typename Type, typename Enable = void>
struct ChooseFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (batch.values[0].is_scalar()) {
return ExecScalarChoose<Type>(ctx, batch, out);
}
return ExecArrayChoose<Type>(ctx, batch, out);
}
};

template <>
struct ChooseFunctor<NullType> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
return Status::OK();
}
};

template <typename Type>
struct ChooseFunctor<Type, enable_if_base_binary<Type>> {
using offset_type = typename Type::offset_type;
using BuilderType = typename TypeTraits<Type>::BuilderType;

static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (batch.values[0].is_scalar()) {
const auto& index_scalar = *batch[0].scalar();
if (!index_scalar.is_valid) {
if (out->is_array()) {
ARROW_ASSIGN_OR_RAISE(
auto temp_array,
MakeArrayOfNull(out->type(), batch.length, ctx->memory_pool()));
*out->mutable_array() = *temp_array->data();
}
return Status::OK();
}
auto index = UnboxScalar<Int64Type>::Unbox(index_scalar);
if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
return Status::IndexError("choose: index ", index, " out of range");
}
auto source = batch.values[index + 1];
if (source.is_scalar() && out->is_array()) {
ARROW_ASSIGN_OR_RAISE(
auto temp_array,
MakeArrayFromScalar(*source.scalar(), batch.length, ctx->memory_pool()));
*out->mutable_array() = *temp_array->data();
} else {
*out = source;
}
return Status::OK();
}

// Row-wise implementation
BuilderType builder(out->type(), ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(batch.length));
int64_t reserve_data = 0;
for (const auto& value : batch.values) {
if (value.is_scalar()) {
if (!value.scalar()->is_valid) continue;
const auto row_length =
checked_cast<const BaseBinaryScalar&>(*value.scalar()).value->size();
reserve_data = std::max<int64_t>(reserve_data, batch.length * row_length);
continue;
}
const ArrayData& arr = *value.array();
const offset_type* offsets = arr.GetValues<offset_type>(1);
const offset_type values_length = offsets[arr.length] - offsets[0];
reserve_data = std::max<int64_t>(reserve_data, values_length);
}
RETURN_NOT_OK(builder.ReserveData(reserve_data));
int64_t row = 0;
RETURN_NOT_OK(VisitArrayValuesInline<Int64Type>(
*batch[0].array(),
[&](int64_t index) {
if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
return Status::IndexError("choose: index ", index, " out of range");
}
const auto& source = batch.values[index + 1];
return CopyValue(source, &builder, row++);
},
[&]() {
row++;
return builder.AppendNull();
}));
auto actual_type = out->type();
std::shared_ptr<Array> temp_output;
RETURN_NOT_OK(builder.Finish(&temp_output));
ArrayData* output = out->mutable_array();
*output = *temp_output->data();
// Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
output->type = std::move(actual_type);
return Status::OK();
}

static Status CopyValue(const Datum& datum, BuilderType* builder, int64_t row) {
if (datum.is_scalar()) {
const auto& scalar = checked_cast<const BaseBinaryScalar&>(*datum.scalar());
if (!scalar.value) return builder->AppendNull();
return builder->Append(scalar.value->data(),
static_cast<offset_type>(scalar.value->size()));
}
const ArrayData& source = *datum.array();
if (!source.MayHaveNulls() ||
BitUtil::GetBit(source.buffers[0]->data(), source.offset + row)) {
const uint8_t* data = source.buffers[2]->data();
const offset_type* offsets = source.GetValues<offset_type>(1);
const offset_type offset0 = offsets[row];
const offset_type offset1 = offsets[row + 1];
return builder->Append(data + offset0, offset1 - offset0);
}
return builder->AppendNull();
}
};

struct ChooseFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
// The first argument is always int64 or promoted to it. The kernel is dispatched
// based on the type of the rest of the arguments.
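// For example, a hypothetical call choose(int16 indices, int32 values, float64 values)
// would dispatch on float64: the indices are cast to int64 and the value arguments are
// promoted to their common numeric type before the kernel is looked up.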
RETURN_NOT_OK(CheckArity(*values));
EnsureDictionaryDecoded(values);
if (values->front().type->id() != Type::INT64) {
values->front().type = int64();
}
if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) {
for (auto it = values->begin() + 1; it != values->end(); it++) {
it->type = type;
}
}
if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *values);
}
};

Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs) {
ValueDescr result = descrs.back();
result.shape = GetBroadcastShape(descrs);
@@ -1652,6 +1874,26 @@ void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_f
}
}

void AddChooseKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(
KernelSignature::Make({Type::INT64, InputType(get_id.id)}, OutputType(LastType),
/*is_varargs=*/true),
exec);
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
kernel.mem_allocation = MemAllocation::PREALLOCATE;
kernel.can_write_into_slices = is_fixed_width(get_id.id);
DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}

void AddPrimitiveChooseKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
const std::vector<std::shared_ptr<DataType>>& types) {
for (auto&& type : types) {
auto exec = GenerateTypeAgnosticPrimitive<ChooseFunctor>(*type);
AddChooseKernel(scalar_function, type, std::move(exec));
}
}

const FunctionDoc if_else_doc{"Choose values based on a condition",
("`cond` must be a Boolean scalar/ array. \n`left` or "
"`right` must be of the same type scalar/ array.\n"
@@ -1679,6 +1921,15 @@ const FunctionDoc coalesce_doc{
"for which the value is not null. If all inputs are null in a row, the output "
"will be null."),
{"*values"}};

const FunctionDoc choose_doc{
"Given indices and arrays, choose the value from the corresponding array for each "
"index",
("For each row, the value of the first argument is used as a 0-based index into the "
"rest of the arguments (i.e. index 0 selects the second argument). The output value "
"is the corresponding value of the selected argument.\n"
"If an index is null, the output will be null."),
{"indices", "*values"}};
} // namespace

void RegisterScalarIfElse(FunctionRegistry* registry) {
@@ -1688,7 +1939,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {

AddPrimitiveIfElseKernels(func, NumericTypes());
AddPrimitiveIfElseKernels(func, TemporalTypes());
AddPrimitiveIfElseKernels(func, {boolean(), day_time_interval(), month_interval()});
AddPrimitiveIfElseKernels(func, IntervalTypes());
AddPrimitiveIfElseKernels(func, {boolean()});
AddNullIfElseKernel(func);
AddBinaryIfElseKernels(func, BaseBinaryTypes());
AddFSBinaryIfElseKernel(func);
Expand All @@ -1699,8 +1951,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
"case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc);
AddPrimitiveCaseWhenKernels(func, NumericTypes());
AddPrimitiveCaseWhenKernels(func, TemporalTypes());
AddPrimitiveCaseWhenKernels(
func, {boolean(), null(), day_time_interval(), month_interval()});
AddPrimitiveCaseWhenKernels(func, IntervalTypes());
AddPrimitiveCaseWhenKernels(func, {boolean(), null()});
AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<Decimal128Type>::Exec);
@@ -1712,8 +1964,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
"coalesce", Arity::VarArgs(/*min_args=*/1), &coalesce_doc);
AddPrimitiveCoalesceKernels(func, NumericTypes());
AddPrimitiveCoalesceKernels(func, TemporalTypes());
AddPrimitiveCoalesceKernels(
func, {boolean(), null(), day_time_interval(), month_interval()});
AddPrimitiveCoalesceKernels(func, IntervalTypes());
AddPrimitiveCoalesceKernels(func, {boolean(), null()});
AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
CoalesceFunctor<FixedSizeBinaryType>::Exec);
AddCoalesceKernel(func, Type::DECIMAL128, CoalesceFunctor<Decimal128Type>::Exec);
@@ -1723,6 +1975,22 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
auto func = std::make_shared<ChooseFunction>("choose", Arity::VarArgs(/*min_args=*/2),
&choose_doc);
AddPrimitiveChooseKernels(func, NumericTypes());
AddPrimitiveChooseKernels(func, TemporalTypes());
AddPrimitiveChooseKernels(func, IntervalTypes());
AddPrimitiveChooseKernels(func, {boolean(), null()});
AddChooseKernel(func, Type::FIXED_SIZE_BINARY,
ChooseFunctor<FixedSizeBinaryType>::Exec);
AddChooseKernel(func, Type::DECIMAL128, ChooseFunctor<Decimal128Type>::Exec);
AddChooseKernel(func, Type::DECIMAL256, ChooseFunctor<Decimal256Type>::Exec);
for (const auto& ty : BaseBinaryTypes()) {
AddChooseKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<ChooseFunctor>(ty));
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
}

} // namespace internal
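
Below is a minimal, hedged sketch (not part of this diff) of how the newly registered "choose" function could be exercised once the change is built. It assumes the public arrow::compute::CallFunction API and the test-only helper arrow::ArrayFromJSON from arrow/testing/gtest_util.h; RunChooseExample is a hypothetical wrapper written for illustration.

#include <iostream>

#include "arrow/api.h"
#include "arrow/compute/api.h"
#include "arrow/testing/gtest_util.h"  // ArrayFromJSON (test-only helper)

arrow::Status RunChooseExample() {
  // The first argument gives a 0-based index into the remaining arguments;
  // a null index produces a null output row.
  auto indices = arrow::ArrayFromJSON(arrow::int64(), "[0, 1, 1, null]");
  auto a = arrow::ArrayFromJSON(arrow::int32(), "[10, 11, 12, 13]");
  auto b = arrow::ArrayFromJSON(arrow::int32(), "[20, 21, 22, 23]");

  ARROW_ASSIGN_OR_RAISE(arrow::Datum result,
                        arrow::compute::CallFunction("choose", {indices, a, b}));
  // Expected per the documented semantics: [10, 21, 22, null]
  std::cout << result.make_array()->ToString() << std::endl;
  return arrow::Status::OK();
}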