Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& desc
return result;
}

// Output type resolver for kernels whose first argument is list-like: the result
// carries the value type of that list type and the broadcast shape of all arguments.
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
  const auto& first_arg_type = *args[0].type;
  const auto& as_list_type = checked_cast<const BaseListType&>(first_arg_type);
  return ValueDescr(as_list_type.value_type(), GetBroadcastShape(args));
}

void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
for (ValueDescr& descr : *descrs) {
if (descr.type->id() == Type::DICTIONARY) {
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& ar
// Reusable type resolvers

Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs);
// Resolves the output descriptor for a kernel whose first argument is list-like:
// the value type of that list type, with the broadcast shape of all the arguments.
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args);

// ----------------------------------------------------------------------
// Generate an array kernel given template classes
Expand Down
114 changes: 114 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,114 @@ const FunctionDoc list_value_length_doc{
"Null values emit a null in the output."),
{"lists"}};

// Kernel implementation of "list_element" for an array of lists and a scalar index.
//
// `Type` is the list type of the first argument, `IndexType` the integer type of
// the index scalar. For every list of the input array the element found at `index`
// is appended to the output; a null list produces a null output slot.
//
// Returns Status::Invalid when the index is null, negative, or out of bounds for
// any non-null list of the input.
template <typename Type, typename IndexType>
struct ListElementArray {
  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
    using ListArrayType = typename TypeTraits<Type>::ArrayType;
    using IndexScalarType = typename TypeTraits<IndexType>::ScalarType;
    const auto& index_scalar = batch[1].scalar_as<IndexScalarType>();
    if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) {
      return Status::Invalid("Index must not be null");
    }
    ListArrayType list_array(batch[0].array());
    const auto index = index_scalar.value;
    if (ARROW_PREDICT_FALSE(index < 0)) {
      // Widen before formatting: an 8-bit index would otherwise be streamed as a
      // character rather than as a number.
      return Status::Invalid("Index ", static_cast<int64_t>(index),
                             " is out of bounds: should be greater than or equal to 0");
    }
    // The index is known to be non-negative here, so widening it to uint64_t is
    // lossless for every integer index type. Comparing in the wide type avoids
    // narrowing the list length to the index type (e.g. a length of 200 cast to
    // int8 wraps to a negative number and would reject every valid index).
    const auto wide_index = static_cast<uint64_t>(index);
    std::unique_ptr<ArrayBuilder> builder;
    RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), list_array.value_type(), &builder));
    RETURN_NOT_OK(builder->Reserve(list_array.length()));
    // Array lengths are 64-bit; a 32-bit counter would overflow on large inputs.
    for (int64_t i = 0; i < list_array.length(); ++i) {
      if (list_array.IsNull(i)) {
        RETURN_NOT_OK(builder->AppendNull());
        continue;
      }
      std::shared_ptr<arrow::Array> value_array = list_array.value_slice(i);
      const int64_t len = value_array->length();
      if (ARROW_PREDICT_FALSE(wide_index >= static_cast<uint64_t>(len))) {
        return Status::Invalid("Index ", wide_index,
                               " is out of bounds: should be in [0, ", len, ")");
      }
      // Safe narrowing: wide_index < len <= INT64_MAX.
      RETURN_NOT_OK(builder->AppendArraySlice(*value_array->data(),
                                              static_cast<int64_t>(wide_index), 1));
    }
    ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
    out->value = result->data();
    return Status::OK();
  }
};

// Kernel implementation of "list_element" for a scalar list and a scalar index.
//
// The first template parameter is unused; `IndexType` is the integer type of the
// index scalar. A null list produces a null scalar of the list's value type.
//
// Returns Status::Invalid when the index is null, negative, or not smaller than
// the length of the (non-null) list.
template <typename, typename IndexType>
struct ListElementScalar {
  static Status Exec(KernelContext* /*ctx*/, const ExecBatch& batch, Datum* out) {
    using IndexScalarType = typename TypeTraits<IndexType>::ScalarType;
    const auto& index_scalar = batch[1].scalar_as<IndexScalarType>();
    if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) {
      return Status::Invalid("Index must not be null");
    }
    const auto& list_scalar = batch[0].scalar_as<BaseListScalar>();
    if (ARROW_PREDICT_FALSE(!list_scalar.is_valid)) {
      out->value = MakeNullScalar(
          checked_cast<const BaseListType&>(*batch[0].type()).value_type());
      return Status::OK();
    }
    const auto& list = list_scalar.value;
    const auto index = index_scalar.value;
    const int64_t len = list->length();
    if (ARROW_PREDICT_FALSE(index < 0)) {
      // Widen before formatting: an 8-bit index would otherwise be streamed as a
      // character rather than as a number.
      return Status::Invalid("Index ", static_cast<int64_t>(index),
                             " is out of bounds: should be in [0, ", len, ")");
    }
    // Compare in a 64-bit unsigned type instead of narrowing the length to the
    // index type: a length that does not fit in e.g. int8 would wrap around and
    // make the bounds check reject valid indices.
    const auto wide_index = static_cast<uint64_t>(index);
    if (ARROW_PREDICT_FALSE(wide_index >= static_cast<uint64_t>(len))) {
      return Status::Invalid("Index ", wide_index,
                             " is out of bounds: should be in [0, ", len, ")");
    }
    // Safe narrowing: wide_index < len <= INT64_MAX.
    ARROW_ASSIGN_OR_RAISE(out->value,
                          list->GetScalar(static_cast<int64_t>(wide_index)));
    return Status::OK();
  }
};

// Adds to `func` the "array of lists, scalar index" kernels for the list type
// `InListType`: one kernel per integer index type.
template <typename InListType>
void AddListElementArrayKernels(ScalarFunction* func) {
  for (const auto& index_type : IntTypes()) {
    std::vector<InputType> in_types{InputType::Array(InListType::type_id),
                                    InputType::Scalar(index_type)};
    auto signature = KernelSignature::Make(std::move(in_types), OutputType{ListValuesType},
                                           /*is_varargs=*/false);
    ScalarKernel kernel{std::move(signature),
                        GenerateInteger<ListElementArray, InListType>({index_type->id()})};
    // The kernel builds its own output (validity included), so nothing is
    // preallocated on its behalf.
    kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
    kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
    DCHECK_OK(func->AddKernel(std::move(kernel)));
  }
}

// Adds the array kernels for each of the list-like types handled here: list,
// large list and fixed-size list.
void AddListElementArrayKernels(ScalarFunction* func) {
  AddListElementArrayKernels<ListType>(func);
  AddListElementArrayKernels<LargeListType>(func);
  AddListElementArrayKernels<FixedSizeListType>(func);
}

// Adds to `func` the "scalar list, scalar index" kernels: one kernel for every
// combination of list-like type id and integer index type.
void AddListElementScalarKernels(ScalarFunction* func) {
  const Type::type list_type_ids[] = {Type::LIST, Type::LARGE_LIST,
                                      Type::FIXED_SIZE_LIST};
  for (const Type::type list_type_id : list_type_ids) {
    for (const auto& index_type : IntTypes()) {
      std::vector<InputType> in_types{InputType::Scalar(list_type_id),
                                      InputType::Scalar(index_type)};
      auto signature = KernelSignature::Make(
          std::move(in_types), OutputType{ListValuesType}, /*is_varargs=*/false);
      ScalarKernel kernel{std::move(signature),
                          GenerateInteger<ListElementScalar, void>({index_type->id()})};
      // The kernel produces its output scalar itself; nothing is preallocated.
      kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
      kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
      DCHECK_OK(func->AddKernel(std::move(kernel)));
    }
  }
}

// User-facing documentation of the "list_element" function: a one-line summary,
// a longer description, and the names of its two arguments.
const FunctionDoc list_element_doc(
    "Compute elements of nested list values using an index",
    ("`lists` must have a list-like type.\n"
     "For each list in `lists`, the element at `index`\n"
     "is emitted. Null values emit a null in the output."),
    {"lists", "index"});

Result<ValueDescr> MakeStructResolve(KernelContext* ctx,
const std::vector<ValueDescr>& descrs) {
auto names = OptionsWrapper<MakeStructOptions>::Get(ctx).field_names;
Expand Down Expand Up @@ -185,6 +293,12 @@ void RegisterScalarNested(FunctionRegistry* registry) {
ListValueLength<LargeListType>));
DCHECK_OK(registry->AddFunction(std::move(list_value_length)));

auto list_element = std::make_shared<ScalarFunction>("list_element", Arity::Binary(),
&list_element_doc);
AddListElementArrayKernels(list_element.get());
AddListElementScalarKernels(list_element.get());
DCHECK_OK(registry->AddFunction(std::move(list_element)));

static MakeStructOptions kDefaultMakeStructOptions;
auto make_struct_function = std::make_shared<ScalarFunction>(
"make_struct", Arity::VarArgs(), &make_struct_doc, &kDefaultMakeStructOptions);
Expand Down
64 changes: 64 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_nested_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,70 @@ TEST(TestScalarNested, ListValueLength) {
"[3, null, 3, 3]");
}

// list_element on list and large_list arrays that contain null elements and a
// null list, for every numeric value type and every integer index type.
TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
  const char* lists_json =
      "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
  for (const auto& value_type : NumericTypes()) {
    // The expected outputs only depend on the value type.
    const auto expected = ArrayFromJSON(value_type, "[5, null, 12, 9, null]");
    const auto expected_all_null = ArrayFromJSON(value_type, "[null]");
    for (const auto& list_type : {list(value_type), large_list(value_type)}) {
      const auto lists = ArrayFromJSON(list_type, lists_json);
      const auto null_lists = ArrayFromJSON(list_type, "[null]");
      for (const auto& index_type : IntTypes()) {
        const auto index = ScalarFromJSON(index_type, "1");
        CheckScalar("list_element", {lists, index}, expected);
        CheckScalar("list_element", {null_lists, index}, expected_all_null);
      }
    }
  }
}

// list_element on fixed-size lists of width 3: picking index 0 yields the first
// element of every list, for every numeric value type and integer index type.
TEST(TestScalarNested, ListElementFixedList) {
  const char* lists_json = "[[7, 5, 81], [6, 4, 8], [3, 12, 2], [1, 43, 87]]";
  for (const auto& value_type : NumericTypes()) {
    const auto lists = ArrayFromJSON(fixed_size_list(value_type, 3), lists_json);
    const auto expected = ArrayFromJSON(value_type, "[7, 6, 3, 1]");
    for (const auto& index_type : IntTypes()) {
      const auto index = ScalarFromJSON(index_type, "0");
      CheckScalar("list_element", {lists, index}, expected);
    }
  }
}

// list_element must fail with StatusCode::Invalid, for both the array and the
// scalar flavor of the list argument, when the index cannot be used.
TEST(TestScalarNested, ListElementInvalid) {
  // Runs list_element on the array input, then on the scalar input, and expects
  // both calls to be rejected.
  auto expect_both_invalid = [](const Datum& lists_array, const Datum& lists_scalar,
                                const std::shared_ptr<Scalar>& index) {
    EXPECT_THAT(CallFunction("list_element", {lists_array, index}),
                Raises(StatusCode::Invalid));
    EXPECT_THAT(CallFunction("list_element", {lists_scalar, index}),
                Raises(StatusCode::Invalid));
  };

  auto lists_array = ArrayFromJSON(list(float32()), "[[0.1, 1.1], [0.2, 1.2]]");
  auto lists_scalar = ScalarFromJSON(list(float32()), "[0.1, 0.2]");

  // invalid index: null
  expect_both_invalid(lists_array, lists_scalar, ScalarFromJSON(int32(), "null"));

  // invalid index: < 0
  expect_both_invalid(lists_array, lists_scalar, ScalarFromJSON(int32(), "-1"));

  // invalid index: >= list.length
  expect_both_invalid(lists_array, lists_scalar, ScalarFromJSON(int32(), "2"));

  // invalid input: an empty list has no element at index 0
  lists_array = ArrayFromJSON(list(float32()), "[[41, 6, 93], [], [2]]");
  lists_scalar = ScalarFromJSON(list(float32()), "[]");
  expect_both_invalid(lists_array, lists_scalar, ScalarFromJSON(int32(), "0"));
}

struct {
Result<Datum> operator()(std::vector<Datum> args) {
return CallFunction("make_struct", args);
Expand Down
5 changes: 0 additions & 5 deletions cpp/src/arrow/compute/kernels/vector_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,6 @@ struct ListParentIndicesArray {
}
};

// Output type resolver for list-like inputs: always an array descriptor holding
// the value type of the first argument's list type.
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
  const auto& first_arg_type = *args[0].type;
  const auto& as_list_type = checked_cast<const BaseListType&>(first_arg_type);
  return ValueDescr::Array(as_list_type.value_type());
}

Result<std::shared_ptr<DataType>> ListParentIndicesType(const DataType& input_type) {
switch (input_type.id()) {
case Type::LIST:
Expand Down
25 changes: 15 additions & 10 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1492,19 +1492,24 @@ value, but smaller than nulls.
Structural transforms
~~~~~~~~~~~~~~~~~~~~~

+--------------------------+------------+--------------------+---------------------+---------+
| Function name | Arity | Input types | Output type | Notes |
+==========================+============+====================+=====================+=========+
| list_flatten | Unary | List-like | List value type | \(1) |
+--------------------------+------------+--------------------+---------------------+---------+
| list_parent_indices | Unary | List-like | Int32 or Int64 | \(2) |
+--------------------------+------------+--------------------+---------------------+---------+

* \(1) The top level of nesting is removed: all values in the list child array,
+--------------------------+------------+------------------------------------+---------------------+---------+
| Function name | Arity | Input types | Output type | Notes |
+==========================+============+====================================+=====================+=========+
| list_element | Binary | List-like (Arg 0), Integral (Arg 1)| List value type | \(1) |
+--------------------------+------------+------------------------------------+---------------------+---------+
| list_flatten | Unary | List-like | List value type | \(2) |
+--------------------------+------------+------------------------------------+---------------------+---------+
| list_parent_indices | Unary | List-like | Int32 or Int64 | \(3) |
+--------------------------+------------+------------------------------------+---------------------+---------+

* \(1) Output is an array of the same length as the input list array. The
output values are the values at the specified index of each child list.

* \(2) The top level of nesting is removed: all values in the list child array,
including nulls, are appended to the output. However, nulls in the parent
list array are discarded.

* \(2) For each value in the list child array, the index at which it is found
* \(3) For each value in the list child array, the index at which it is found
in the list array is appended to the output. Nulls in the parent list array
are discarded. Output type is Int32 for List and FixedSizeList, Int64 for
LargeList.
Expand Down
3 changes: 2 additions & 1 deletion docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ Structural Transforms
is_nan
is_null
is_valid
list_value_length
list_element
list_flatten
list_parent_indices
list_value_length
18 changes: 18 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2133,3 +2133,21 @@ def test_case_when():
[False, True, None]),
[1, 2, 3],
[11, 12, 13]) == pa.array([1, 12, None])


def test_list_element():
    """list_element picks the struct found at a given index in each list."""
    element_type = pa.struct([('a', pa.float64()), ('b', pa.int8())])
    list_type = pa.list_(element_type)
    l1 = [{'a': .4, 'b': 2}, None, {'a': .2, 'b': 4}, None, {'a': 5.6, 'b': 6}]
    l2 = [None, {'a': .52, 'b': 3}, {'a': .7, 'b': 4}, None, {'a': .6, 'b': 8}]
    lists = pa.array([l1, l2], list_type)

    # (index, expected element taken from l1 and from l2)
    cases = [
        (1, [None, {'a': 0.52, 'b': 3}]),
        (4, [{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}]),
    ]
    for index, expected_values in cases:
        result = pa.compute.list_element(lists, index)
        expected = pa.array(expected_values, element_type)
        assert result.equals(expected)