Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ static auto kRankOptionsType = GetFunctionOptionsType<RankOptions>(
DataMember("tiebreaker", &RankOptions::tiebreaker));
static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
// FunctionOptionsType instance for ListFlattenOptions, built from its single
// `recursive` data member (exposed under the name "recursive").
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -224,6 +226,10 @@ PairwiseOptions::PairwiseOptions(int64_t periods)
: FunctionOptions(internal::kPairwiseOptionsType), periods(periods) {}
constexpr char PairwiseOptions::kTypeName[];

// Constructs list_flatten options bound to internal::kListFlattenOptionsType
// and stores the requested `recursive` flag.
ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
// Out-of-class definition of the static constexpr member declared in the class.
// NOTE(review): presumably kept for ODR-use under pre-C++17 rules, matching the
// other options types in this file -- confirm.
constexpr char ListFlattenOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -237,6 +243,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
}
} // namespace internal

Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,18 @@ class ARROW_EXPORT PairwiseOptions : public FunctionOptions {
int64_t periods = 1;
};

/// \brief Options controlling the list_flatten function
class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
 public:
  /// Name under which this options type is identified.
  static constexpr const char kTypeName[] = "ListFlattenOptions";

  /// \brief Build options; by default only one list level is flattened.
  explicit ListFlattenOptions(bool recursive = false);

  /// \brief Return default-constructed options (recursive == false).
  static ListFlattenOptions Defaults() { return ListFlattenOptions(); }

  /// \brief When true, flattening is applied recursively until a non-list
  /// array is formed.
  bool recursive = false;
};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down
21 changes: 18 additions & 3 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <mutex>
#include <vector>

#include "arrow/compute/api_vector.h"
#include "arrow/type_fwd.h"

namespace arrow {
Expand Down Expand Up @@ -56,9 +57,23 @@ Result<TypeHolder> LastType(KernelContext*, const std::vector<TypeHolder>& types
return types.back();
}

Result<TypeHolder> ListValuesType(KernelContext*, const std::vector<TypeHolder>& args) {
const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
return list_type.value_type().get();
/// \brief Output type resolver yielding the value type of a list-like input.
///
/// Without kernel state, or when ListFlattenOptions::recursive is false, the
/// result is the immediate value type of args[0]. Otherwise nested list /
/// list-view value types are unwrapped until a non-list type is reached.
Result<TypeHolder> ListValuesType(KernelContext* ctx,
                                  const std::vector<TypeHolder>& args) {
  const auto* current_list = checked_cast<const BaseListType*>(args[0].type);
  auto* element_type = current_list->value_type().get();

  // NOTE(review): a non-null state is assumed to be an
  // OptionsWrapper<ListFlattenOptions>; confirm that no other kernel using this
  // resolver installs a different state type.
  const bool unwrap_all_levels =
      ctx->state() != nullptr && OptionsWrapper<ListFlattenOptions>::Get(ctx).recursive;

  if (unwrap_all_levels) {
    // Descend one nesting level per iteration while the element is itself a list.
    while (is_list(element_type->id()) || is_list_view(element_type->id())) {
      current_list = checked_cast<const BaseListType*>(element_type);
      element_type = current_list->value_type().get();
    }
  }
  return element_type;
}

void EnsureDictionaryDecoded(std::vector<TypeHolder>* types) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ static void VisitTwoArrayValuesInline(const ArraySpan& arr0, const ArraySpan& ar

Result<TypeHolder> FirstType(KernelContext*, const std::vector<TypeHolder>& types);
Result<TypeHolder> LastType(KernelContext*, const std::vector<TypeHolder>& types);
Result<TypeHolder> ListValuesType(KernelContext*, const std::vector<TypeHolder>& types);
Result<TypeHolder> ListValuesType(KernelContext* ctx,
const std::vector<TypeHolder>& types);

// ----------------------------------------------------------------------
// Helpers for iterating over common DataType instances for adding kernels to
Expand Down
49 changes: 39 additions & 10 deletions cpp/src/arrow/compute/kernels/scalar_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/result.h"
#include "arrow/type_fwd.h"
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_generate.h"
Expand All @@ -41,10 +42,17 @@ Status ListValueLength(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou
const ArraySpan& arr = batch[0].array;
ArraySpan* out_arr = out->array_span_mutable();
auto out_values = out_arr->GetValues<offset_type>(1);
const offset_type* offsets = arr.GetValues<offset_type>(1);
// Offsets are always well-defined and monotonic, even for null values
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = offsets[i + 1] - offsets[i];
if (is_list_view(*arr.type)) {
const auto* sizes = arr.GetValues<offset_type>(2);
if (arr.length > 0) {
memcpy(out_values, sizes, arr.length * sizeof(offset_type));
}
} else {
const offset_type* offsets = arr.GetValues<offset_type>(1);
// Offsets are always well-defined and monotonic, even for null values
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = offsets[i + 1] - offsets[i];
}
}
return Status::OK();
}
Expand All @@ -59,6 +67,30 @@ Status FixedSizeListValueLength(KernelContext* ctx, const ExecSpan& batch,
return Status::OK();
}

/// Adds to `func` a kernel that accepts `InListType` inputs, produces
/// `out_type`, and executes ListValueLength<InListType>.
template <typename InListType>
void AddListValueLengthKernel(ScalarFunction* func,
                              const std::shared_ptr<DataType>& out_type) {
  std::vector<InputType> input_types = {InputType(InListType::type_id)};
  ScalarKernel length_kernel(std::move(input_types), out_type,
                             ListValueLength<InListType>);
  DCHECK_OK(func->AddKernel(std::move(length_kernel)));
}

/// Specialization for fixed-size lists: the kernel matches
/// Type::FIXED_SIZE_LIST inputs and executes FixedSizeListValueLength.
template <>
void AddListValueLengthKernel<FixedSizeListType>(
    ScalarFunction* func, const std::shared_ptr<DataType>& out_type) {
  std::vector<InputType> input_types = {InputType(Type::FIXED_SIZE_LIST)};
  ScalarKernel length_kernel(std::move(input_types), out_type,
                             FixedSizeListValueLength);
  DCHECK_OK(func->AddKernel(std::move(length_kernel)));
}

/// Adds one value-length kernel per supported list type to `func`.
/// List, ListView and FixedSizeList output int32; LargeList and LargeListView
/// output int64. Registration order is unchanged.
void AddListValueLengthKernels(ScalarFunction* func) {
  const std::shared_ptr<DataType> length32 = int32();
  const std::shared_ptr<DataType> length64 = int64();

  AddListValueLengthKernel<ListType>(func, length32);
  AddListValueLengthKernel<LargeListType>(func, length64);
  AddListValueLengthKernel<ListViewType>(func, length32);
  AddListValueLengthKernel<LargeListViewType>(func, length64);
  AddListValueLengthKernel<FixedSizeListType>(func, length32);
}

const FunctionDoc list_value_length_doc{
"Compute list lengths",
("`lists` must have a list-like type.\n"
Expand Down Expand Up @@ -399,6 +431,8 @@ void AddListElementKernels(ScalarFunction* func) {
// Instantiates the templated AddListElementKernels helper once per list type:
// ListType, LargeListType, ListViewType and LargeListViewType are paired with
// ListElement, while FixedSizeListType is paired with FixedSizeListElement.
void AddListElementKernels(ScalarFunction* func) {
AddListElementKernels<ListType, ListElement>(func);
AddListElementKernels<LargeListType, ListElement>(func);
AddListElementKernels<ListViewType, ListElement>(func);
AddListElementKernels<LargeListViewType, ListElement>(func);
AddListElementKernels<FixedSizeListType, FixedSizeListElement>(func);
}

Expand Down Expand Up @@ -824,12 +858,7 @@ const FunctionDoc map_lookup_doc{
void RegisterScalarNested(FunctionRegistry* registry) {
auto list_value_length = std::make_shared<ScalarFunction>(
"list_value_length", Arity::Unary(), list_value_length_doc);
DCHECK_OK(list_value_length->AddKernel({InputType(Type::LIST)}, int32(),
ListValueLength<ListType>));
DCHECK_OK(list_value_length->AddKernel({InputType(Type::FIXED_SIZE_LIST)}, int32(),
FixedSizeListValueLength));
DCHECK_OK(list_value_length->AddKernel({InputType(Type::LARGE_LIST)}, int64(),
ListValueLength<LargeListType>));
AddListValueLengthKernels(list_value_length.get());
DCHECK_OK(registry->AddFunction(std::move(list_value_length)));

auto list_element =
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_nested_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,21 @@ namespace arrow {
namespace compute {

static std::shared_ptr<DataType> GetOffsetType(const DataType& type) {
return type.id() == Type::LIST ? int32() : int64();
switch (type.id()) {
case Type::LIST:
case Type::LIST_VIEW:
return int32();
case Type::LARGE_LIST:
case Type::LARGE_LIST_VIEW:
return int64();
default:
Unreachable("Unexpected type");
}
}

TEST(TestScalarNested, ListValueLength) {
for (auto ty : {list(int32()), large_list(int32())}) {
for (auto ty : {list(int32()), large_list(int32()), list_view(int32()),
large_list_view(int32())}) {
CheckScalarUnary("list_value_length", ty, "[[0, null, 1], null, [2, 3], []]",
GetOffsetType(*ty), "[3, null, 2, 0]");
}
Expand All @@ -47,7 +57,8 @@ TEST(TestScalarNested, ListValueLength) {
TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
auto sample = "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
for (auto ty : NumericTypes()) {
for (auto list_type : {list(ty), large_list(ty)}) {
for (auto list_type :
{list(ty), large_list(ty), list_view(ty), large_list_view(ty)}) {
auto input = ArrayFromJSON(list_type, sample);
auto null_input = ArrayFromJSON(list_type, "[null]");
for (auto index_type : IntTypes()) {
Expand Down
54 changes: 41 additions & 13 deletions cpp/src/arrow/compute/kernels/vector_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Vector kernels involving nested types

#include "arrow/array/array_base.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/result.h"
#include "arrow/visit_type_inline.h"
Expand All @@ -29,8 +30,13 @@ namespace {

template <typename Type>
Status ListFlatten(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
auto recursive = OptionsWrapper<ListFlattenOptions>::Get(ctx).recursive;
typename TypeTraits<Type>::ArrayType list_array(batch[0].array.ToArrayData());
ARROW_ASSIGN_OR_RAISE(auto result, list_array.Flatten(ctx->memory_pool()));

auto pool = ctx->memory_pool();
ARROW_ASSIGN_OR_RAISE(auto result, (recursive ? list_array.FlattenRecursively(pool)
: list_array.Flatten(pool)));

out->value = std::move(result->data());
return Status::OK();
}
Expand Down Expand Up @@ -107,10 +113,15 @@ struct ListParentIndicesArray {

const FunctionDoc list_flatten_doc(
"Flatten list values",
("`lists` must have a list-like type.\n"
"Return an array with the top list level flattened.\n"
"Top-level null values in `lists` do not emit anything in the input."),
{"lists"});
("`lists` must have a list-like type (lists, list-views, and\n"
"fixed-size lists).\n"
"Return an array with the top list level flattened unless\n"
"`recursive` is set to true in ListFlattenOptions. When that\n"
"is the case, flattening happens recursively until a non-list\n"
"array is formed.\n"
"\n"
"Null list values do not emit anything to the output."),
{"lists"}, "ListFlattenOptions");

const FunctionDoc list_parent_indices_doc(
"Compute parent indices of nested list values",
Expand Down Expand Up @@ -153,17 +164,34 @@ class ListParentIndicesFunction : public MetaFunction {
}
};

/// Returns a pointer to a function-local static ListFlattenOptions initialized
/// from ListFlattenOptions::Defaults(); the same object is returned on every call.
const ListFlattenOptions* GetDefaultListFlattenOptions() {
  static const ListFlattenOptions kDefaults = ListFlattenOptions::Defaults();
  return &kDefaults;
}

/// Adds to `func` a flatten kernel for `InListType` inputs. The output type is
/// resolved by ListValuesType, execution is done by ListFlatten<InListType>, and
/// kernel state is initialized via OptionsWrapper<ListFlattenOptions>::Init.
template <typename InListType>
void AddBaseListFlattenKernels(VectorFunction* func) {
  std::vector<InputType> input_types = {InputType(InListType::type_id)};
  auto resolved_out_type = OutputType(ListValuesType);
  VectorKernel flatten_kernel(std::move(input_types), resolved_out_type,
                              ListFlatten<InListType>,
                              OptionsWrapper<ListFlattenOptions>::Init);
  DCHECK_OK(func->AddKernel(std::move(flatten_kernel)));
}

// Instantiates the templated AddBaseListFlattenKernels helper for each list
// type handled here: ListType, LargeListType, FixedSizeListType, ListViewType
// and LargeListViewType.
void AddBaseListFlattenKernels(VectorFunction* func) {
AddBaseListFlattenKernels<ListType>(func);
AddBaseListFlattenKernels<LargeListType>(func);
AddBaseListFlattenKernels<FixedSizeListType>(func);
AddBaseListFlattenKernels<ListViewType>(func);
AddBaseListFlattenKernels<LargeListViewType>(func);
}

} // namespace

void RegisterVectorNested(FunctionRegistry* registry) {
auto flatten =
std::make_shared<VectorFunction>("list_flatten", Arity::Unary(), list_flatten_doc);
DCHECK_OK(flatten->AddKernel({Type::LIST}, OutputType(ListValuesType),
ListFlatten<ListType>));
DCHECK_OK(flatten->AddKernel({Type::FIXED_SIZE_LIST}, OutputType(ListValuesType),
ListFlatten<FixedSizeListType>));
DCHECK_OK(flatten->AddKernel({Type::LARGE_LIST}, OutputType(ListValuesType),
ListFlatten<LargeListType>));
auto flatten = std::make_shared<VectorFunction>(
"list_flatten", Arity::Unary(), list_flatten_doc, GetDefaultListFlattenOptions());
AddBaseListFlattenKernels(flatten.get());
DCHECK_OK(registry->AddFunction(std::move(flatten)));

DCHECK_OK(registry->AddFunction(std::make_shared<ListParentIndicesFunction>()));
Expand Down
Loading