From 2369457164c5b5f8beeb1b2e31a5dde57515bdd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Percy=20Camilo=20Trive=C3=B1o=20Aucahuasi?= Date: Tue, 14 Sep 2021 17:13:15 -0500 Subject: [PATCH] implement new vector function: list_element ARROW-12669 add python test and C++ docs for the new function: list_element convert list_element into a scalar function minor changes format support scalar inputs for list_element, improve tests minor changes less generated code thanks to some template tricks less generated code, again using template tricks --- .../arrow/compute/kernels/codegen_internal.cc | 5 + .../arrow/compute/kernels/codegen_internal.h | 1 + .../arrow/compute/kernels/scalar_nested.cc | 114 ++++++++++++++++++ .../compute/kernels/scalar_nested_test.cc | 64 ++++++++++ .../arrow/compute/kernels/vector_nested.cc | 5 - docs/source/cpp/compute.rst | 25 ++-- docs/source/python/api/compute.rst | 3 +- python/pyarrow/tests/test_compute.py | 18 +++ 8 files changed, 219 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 5ac32c044bb..fe4b593b481 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -66,6 +66,11 @@ Result FirstType(KernelContext*, const std::vector& desc return result; } +Result ListValuesType(KernelContext*, const std::vector& args) { + const auto& list_type = checked_cast(*args[0].type); + return ValueDescr(list_type.value_type(), GetBroadcastShape(args)); +} + void EnsureDictionaryDecoded(std::vector* descrs) { for (ValueDescr& descr : *descrs) { if (descr.type->id() == Type::DICTIONARY) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 6a7261eb653..f9ce34b06e0 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -395,6 +395,7 @@ static void 
VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& ar // Reusable type resolvers Result FirstType(KernelContext*, const std::vector& descrs); +Result ListValuesType(KernelContext*, const std::vector& args); // ---------------------------------------------------------------------- // Generate an array kernel given template classes diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc index 9ffe8bf1587..aafaeb34159 100644 --- a/cpp/src/arrow/compute/kernels/scalar_nested.cc +++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc @@ -80,6 +80,114 @@ const FunctionDoc list_value_length_doc{ "Null values emit a null in the output."), {"lists"}}; +template +struct ListElementArray { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + using ListArrayType = typename TypeTraits::ArrayType; + using IndexScalarType = typename TypeTraits::ScalarType; + const auto& index_scalar = batch[1].scalar_as(); + if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) { + return Status::Invalid("Index must not be null"); + } + ListArrayType list_array(batch[0].array()); + auto index = index_scalar.value; + if (ARROW_PREDICT_FALSE(index < 0)) { + return Status::Invalid("Index ", index, + " is out of bounds: should be greater than or equal to 0"); + } + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), list_array.value_type(), &builder)); + RETURN_NOT_OK(builder->Reserve(list_array.length())); + for (int i = 0; i < list_array.length(); ++i) { + if (list_array.IsNull(i)) { + RETURN_NOT_OK(builder->AppendNull()); + continue; + } + std::shared_ptr value_array = list_array.value_slice(i); + auto len = value_array->length(); + if (ARROW_PREDICT_FALSE(index >= static_cast(len))) { + return Status::Invalid("Index ", index, " is out of bounds: should be in [0, ", + len, ")"); + } + RETURN_NOT_OK(builder->AppendArraySlice(*value_array->data(), index, 1)); + } + ARROW_ASSIGN_OR_RAISE(auto 
result, builder->Finish()); + out->value = result->data(); + return Status::OK(); + } +}; + +template +struct ListElementScalar { + static Status Exec(KernelContext* /*ctx*/, const ExecBatch& batch, Datum* out) { + using IndexScalarType = typename TypeTraits::ScalarType; + const auto& index_scalar = batch[1].scalar_as(); + if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) { + return Status::Invalid("Index must not be null"); + } + const auto& list_scalar = batch[0].scalar_as(); + if (ARROW_PREDICT_FALSE(!list_scalar.is_valid)) { + out->value = MakeNullScalar( + checked_cast(*batch[0].type()).value_type()); + return Status::OK(); + } + auto list = list_scalar.value; + auto index = index_scalar.value; + auto len = list->length(); + if (ARROW_PREDICT_FALSE(index < 0 || + index >= static_cast(len))) { + return Status::Invalid("Index ", index, " is out of bounds: should be in [0, ", len, + ")"); + } + ARROW_ASSIGN_OR_RAISE(out->value, list->GetScalar(index)); + return Status::OK(); + } +}; + +template +void AddListElementArrayKernels(ScalarFunction* func) { + for (const auto& index_type : IntTypes()) { + auto inputs = {InputType::Array(InListType::type_id), InputType::Scalar(index_type)}; + auto output = OutputType{ListValuesType}; + auto sig = KernelSignature::Make(std::move(inputs), std::move(output), + /*is_varargs=*/false); + auto scalar_exec = GenerateInteger({index_type->id()}); + ScalarKernel kernel{std::move(sig), std::move(scalar_exec)}; + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } +} + +void AddListElementArrayKernels(ScalarFunction* func) { + AddListElementArrayKernels(func); + AddListElementArrayKernels(func); + AddListElementArrayKernels(func); +} + +void AddListElementScalarKernels(ScalarFunction* func) { + for (const auto list_type_id : {Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST}) { + for (const auto& index_type : 
IntTypes()) { + auto inputs = {InputType::Scalar(list_type_id), InputType::Scalar(index_type)}; + auto output = OutputType{ListValuesType}; + auto sig = KernelSignature::Make(std::move(inputs), std::move(output), + /*is_varargs=*/false); + auto scalar_exec = GenerateInteger({index_type->id()}); + ScalarKernel kernel{std::move(sig), std::move(scalar_exec)}; + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } + } +} + +const FunctionDoc list_element_doc( + "Compute elements of nested list values using an index", + ("`lists` must have a list-like type.\n" + "For each value in each list of `lists`, the element at `index`\n" + "is emitted. Null values emit a null in the output."), + {"lists", "index"}); + Result MakeStructResolve(KernelContext* ctx, const std::vector& descrs) { auto names = OptionsWrapper::Get(ctx).field_names; @@ -185,6 +293,12 @@ void RegisterScalarNested(FunctionRegistry* registry) { ListValueLength)); DCHECK_OK(registry->AddFunction(std::move(list_value_length))); + auto list_element = std::make_shared("list_element", Arity::Binary(), + &list_element_doc); + AddListElementArrayKernels(list_element.get()); + AddListElementScalarKernels(list_element.get()); + DCHECK_OK(registry->AddFunction(std::move(list_element))); + static MakeStructOptions kDefaultMakeStructOptions; auto make_struct_function = std::make_shared( "make_struct", Arity::VarArgs(), &make_struct_doc, &kDefaultMakeStructOptions); diff --git a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc index 5cabed35406..cb16257399d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc @@ -43,6 +43,70 @@ TEST(TestScalarNested, ListValueLength) { "[3, null, 3, 3]"); } +TEST(TestScalarNested, ListElementNonFixedListWithNulls) { + auto sample = "[[7, 5, 
81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]"; + for (auto ty : NumericTypes()) { + for (auto list_type : {list(ty), large_list(ty)}) { + auto input = ArrayFromJSON(list_type, sample); + auto null_input = ArrayFromJSON(list_type, "[null]"); + for (auto index_type : IntTypes()) { + auto index = ScalarFromJSON(index_type, "1"); + auto expected = ArrayFromJSON(ty, "[5, null, 12, 9, null]"); + auto expected_null = ArrayFromJSON(ty, "[null]"); + CheckScalar("list_element", {input, index}, expected); + CheckScalar("list_element", {null_input, index}, expected_null); + } + } + } +} + +TEST(TestScalarNested, ListElementFixedList) { + auto sample = "[[7, 5, 81], [6, 4, 8], [3, 12, 2], [1, 43, 87]]"; + for (auto ty : NumericTypes()) { + auto input = ArrayFromJSON(fixed_size_list(ty, 3), sample); + for (auto index_type : IntTypes()) { + auto index = ScalarFromJSON(index_type, "0"); + auto expected = ArrayFromJSON(ty, "[7, 6, 3, 1]"); + CheckScalar("list_element", {input, index}, expected); + } + } +} + +TEST(TestScalarNested, ListElementInvalid) { + auto input_array = ArrayFromJSON(list(float32()), "[[0.1, 1.1], [0.2, 1.2]]"); + auto input_scalar = ScalarFromJSON(list(float32()), "[0.1, 0.2]"); + + // invalid index: null + auto index = ScalarFromJSON(int32(), "null"); + EXPECT_THAT(CallFunction("list_element", {input_array, index}), + Raises(StatusCode::Invalid)); + EXPECT_THAT(CallFunction("list_element", {input_scalar, index}), + Raises(StatusCode::Invalid)); + + // invalid index: < 0 + index = ScalarFromJSON(int32(), "-1"); + EXPECT_THAT(CallFunction("list_element", {input_array, index}), + Raises(StatusCode::Invalid)); + EXPECT_THAT(CallFunction("list_element", {input_scalar, index}), + Raises(StatusCode::Invalid)); + + // invalid index: >= list.length + index = ScalarFromJSON(int32(), "2"); + EXPECT_THAT(CallFunction("list_element", {input_array, index}), + Raises(StatusCode::Invalid)); + EXPECT_THAT(CallFunction("list_element", {input_scalar, index}), + 
Raises(StatusCode::Invalid)); + + // invalid input: an empty list has no element at index 0 + input_array = ArrayFromJSON(list(float32()), "[[41, 6, 93], [], [2]]"); + input_scalar = ScalarFromJSON(list(float32()), "[]"); + index = ScalarFromJSON(int32(), "0"); + EXPECT_THAT(CallFunction("list_element", {input_array, index}), + Raises(StatusCode::Invalid)); + EXPECT_THAT(CallFunction("list_element", {input_scalar, index}), + Raises(StatusCode::Invalid)); +} + struct { Result operator()(std::vector args) { return CallFunction("make_struct", args); diff --git a/cpp/src/arrow/compute/kernels/vector_nested.cc b/cpp/src/arrow/compute/kernels/vector_nested.cc index 974d0f0e779..f4c61ba7472 100644 --- a/cpp/src/arrow/compute/kernels/vector_nested.cc +++ b/cpp/src/arrow/compute/kernels/vector_nested.cc @@ -110,11 +110,6 @@ struct ListParentIndicesArray { } }; -Result ListValuesType(KernelContext*, const std::vector& args) { - const auto& list_type = checked_cast(*args[0].type); - return ValueDescr::Array(list_type.value_type()); -} - Result> ListParentIndicesType(const DataType& input_type) { switch (input_type.id()) { case Type::LIST: diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 9db9c5bc563..b10c7a120b2 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1492,19 +1492,24 @@ value, but smaller than nulls. 
Structural transforms ~~~~~~~~~~~~~~~~~~~~~ -+--------------------------+------------+--------------------+---------------------+---------+ -| Function name | Arity | Input types | Output type | Notes | -+==========================+============+====================+=====================+=========+ -| list_flatten | Unary | List-like | List value type | \(1) | -+--------------------------+------------+--------------------+---------------------+---------+ -| list_parent_indices | Unary | List-like | Int32 or Int64 | \(2) | -+--------------------------+------------+--------------------+---------------------+---------+ - -* \(1) The top level of nesting is removed: all values in the list child array, ++--------------------------+------------+------------------------------------+---------------------+---------+ +| Function name | Arity | Input types | Output type | Notes | ++==========================+============+====================================+=====================+=========+ +| list_element | Binary | List-like (Arg 0), Integral (Arg 1)| List value type | \(1) | ++--------------------------+------------+------------------------------------+---------------------+---------+ +| list_flatten | Unary | List-like | List value type | \(2) | ++--------------------------+------------+------------------------------------+---------------------+---------+ +| list_parent_indices | Unary | List-like | Int32 or Int64 | \(3) | ++--------------------------+------------+------------------------------------+---------------------+---------+ + +* \(1) Output is an array of the same length as the input list array. The + output values are the values at the specified index of each child list. + +* \(2) The top level of nesting is removed: all values in the list child array, including nulls, are appended to the output. However, nulls in the parent list array are discarded. 
-* \(2) For each value in the list child array, the index at which it is found +* \(3) For each value in the list child array, the index at which it is found in the list array is appended to the output. Nulls in the parent list array are discarded. Output type is Int32 for List and FixedSizeList, Int64 for LargeList. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index dff2a0052f4..00a59b8eef9 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -371,6 +371,7 @@ Structural Transforms is_nan is_null is_valid - list_value_length + list_element list_flatten list_parent_indices + list_value_length diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index b33e482df89..b3e0a41c597 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2133,3 +2133,21 @@ def test_case_when(): [False, True, None]), [1, 2, 3], [11, 12, 13]) == pa.array([1, 12, None]) + + +def test_list_element(): + element_type = pa.struct([('a', pa.float64()), ('b', pa.int8())]) + list_type = pa.list_(element_type) + l1 = [{'a': .4, 'b': 2}, None, {'a': .2, 'b': 4}, None, {'a': 5.6, 'b': 6}] + l2 = [None, {'a': .52, 'b': 3}, {'a': .7, 'b': 4}, None, {'a': .6, 'b': 8}] + lists = pa.array([l1, l2], list_type) + + index = 1 + result = pa.compute.list_element(lists, index) + expected = pa.array([None, {'a': 0.52, 'b': 3}], element_type) + assert result.equals(expected) + + index = 4 + result = pa.compute.list_element(lists, index) + expected = pa.array([{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}], element_type) + assert result.equals(expected)