From 7b840c7dc01799045844fb37def1da8cb69c1c89 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 12 Mar 2019 16:55:22 -0400 Subject: [PATCH 01/14] first draft of take kernel impl --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/array/builder_binary.cc | 4 +- cpp/src/arrow/array/builder_binary.h | 17 ++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + cpp/src/arrow/compute/kernels/take-test.cc | 143 +++++++++++++ cpp/src/arrow/compute/kernels/take.cc | 213 +++++++++++++++++++ cpp/src/arrow/compute/kernels/take.h | 71 +++++++ cpp/src/arrow/testing/gtest_util.h | 5 + cpp/src/arrow/type_traits.h | 4 +- 9 files changed, 455 insertions(+), 4 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/take-test.cc create mode 100644 cpp/src/arrow/compute/kernels/take.cc create mode 100644 cpp/src/arrow/compute/kernels/take.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 6854f14f068..de48c77463b 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -158,6 +158,7 @@ if(ARROW_COMPUTE) compute/kernels/hash.cc compute/kernels/mean.cc compute/kernels/sum.cc + compute/kernels/take.cc compute/kernels/util-internal.cc compute/operations/cast.cc compute/operations/literal.cc) diff --git a/cpp/src/arrow/array/builder_binary.cc b/cpp/src/arrow/array/builder_binary.cc index 4fef135b203..26c6cb4d372 100644 --- a/cpp/src/arrow/array/builder_binary.cc +++ b/cpp/src/arrow/array/builder_binary.cc @@ -232,8 +232,8 @@ Status FixedSizeBinaryBuilder::AppendValues(const uint8_t* data, int64_t length, Status FixedSizeBinaryBuilder::AppendNull() { RETURN_NOT_OK(Reserve(1)); - UnsafeAppendToBitmap(false); - return byte_builder_.Advance(byte_width_); + UnsafeAppendNull(); + return Status::OK(); } void FixedSizeBinaryBuilder::Reset() { diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index c3a459b39fc..a75512fb653 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -189,6 +189,18 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { return byte_builder_.Append(value, byte_width_); } + void UnsafeAppend(const uint8_t* value) { + UnsafeAppendToBitmap(true); + byte_builder_.UnsafeAppend(value, byte_width_); + } + + void UnsafeAppend(util::string_view value) { +#ifndef NDEBUG + CheckValueSize(static_cast(value.size())); +#endif + UnsafeAppend(reinterpret_cast(value.data())); + } + Status Append(const char* value) { return Append(reinterpret_cast(value)); } @@ -218,6 +230,11 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { const uint8_t* valid_bytes = NULLPTR); Status AppendNull(); + void UnsafeAppendNull() { + UnsafeAppendToBitmap(false); + byte_builder_.UnsafeAdvance(byte_width_); + } + void Reset() override; Status Resize(int64_t capacity) override; Status FinishInternal(std::shared_ptr* out) override; diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 5d78747bf93..abdc092a590 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -20,6 +20,7 @@ arrow_install_all_headers("arrow/compute/kernels") add_arrow_test(boolean-test PREFIX "arrow-compute") add_arrow_test(cast-test PREFIX "arrow-compute") add_arrow_test(hash-test PREFIX "arrow-compute") +add_arrow_test(take-test PREFIX "arrow-compute") add_arrow_test(util-internal-test PREFIX "arrow-compute") # Aggregates diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc new file mode 100644 index 00000000000..b2094b994b8 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// returnGegarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/take.h" +#include "arrow/compute/test-util.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" + +namespace arrow { +namespace compute { + +using util::string_view; + +template +class TestTakeKernel : public ComputeFixture, public TestBase { + protected: + void AssertTake(const std::shared_ptr& type, const std::string& values, + const std::string& indices, TakeOptions options, + const std::string& expected) { + std::shared_ptr actual; + ASSERT_OK(this->Take(type, values, indices, options, &actual)); + AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); + } + Status Take(const std::shared_ptr& type, const std::string& values, + const std::string& indices, TakeOptions options, + std::shared_ptr* out) { + return arrow::compute::Take(&this->ctx_, *ArrayFromJSON(type, values), + *ArrayFromJSON(int8(), indices), options, out); + } +}; + +class TestTakeKernelWithNull : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(utf8(), values, indices, options, expected); + } +}; + +TEST_F(TestTakeKernelWithNull, TakeNull) { + TakeOptions options; + this->AssertTake("[null, null, null]", "[0, 1, 0]", options, "[null, null, null]"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, + this->Take(null(), "[null, null, null]", "[0, 9, 0]", options, &arr)); +} + +class TestTakeKernelWithBoolean : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(boolean(), values, indices, options, + expected); + } +}; + +TEST_F(TestTakeKernelWithBoolean, TakeBoolean) { + TakeOptions options; + this->AssertTake("[true, false, true]", "[0, 1, 0]", options, "[true, false, true]"); + this->AssertTake("[null, false, true]", "[0, 1, 0]", options, "[null, false, null]"); + this->AssertTake("[true, false, true]", "[null, 1, 0]", options, "[null, false, true]"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, + this->Take(boolean(), "[true, false, true]", "[0, 9, 0]", options, &arr)); + + options.out_of_bounds = TakeOptions::TONULL; + this->AssertTake("[true, false, true]", "[0, 9, 0]", options, "[true, null, true]"); +} + +template +class TestTakeKernelWithNumeric : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(type_singleton(), values, indices, options, + expected); + } + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); + } +}; + +TYPED_TEST_CASE(TestTakeKernelWithNumeric, NumericArrowTypes); +TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { + TakeOptions options; + this->AssertTake("[7, 8, 9]", "[0, 1, 0]", options, "[7, 8, 7]"); + this->AssertTake("[null, 8, 9]", "[0, 1, 0]", options, "[null, 8, null]"); + this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", "[0, 9, 0]", + options, &arr)); + + options.out_of_bounds = TakeOptions::TONULL; + this->AssertTake("[7, 8, 9]", "[0, 9, 0]", options, "[7, null, 7]"); +} + +class TestTakeKernelWithString : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(utf8(), values, indices, options, expected); + } +}; + +TEST_F(TestTakeKernelWithString, TakeString) { + TakeOptions options; + this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", options, R"(["a", "b", "a"])"); + this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", options, "[null, \"b\", null]"); + this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", options, R"([null, "b", "a"])"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, + this->Take(utf8(), R"(["a", "b", "c"])", "[0, 9, 0]", options, &arr)); + + options.out_of_bounds = TakeOptions::TONULL; + this->AssertTake(R"(["a", "b", "c"])", "[0, 9, 0]", options, R"(["a", null, "a"])"); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc new file mode 100644 index 00000000000..79ca60a6234 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// returnGegarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/builder.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/take.h" +#include "arrow/visitor_inline.h" + +namespace arrow { +namespace compute { + +Status Take(FunctionContext* context, const Array& values, const Array& indices, + const TakeOptions& options, std::shared_ptr* out) { + TakeKernel kernel(values.type(), options); + Datum out_datum; + RETURN_NOT_OK(kernel.Call(context, values.data(), indices.data(), &out_datum)); + *out = MakeArray(out_datum.array()); + return Status::OK(); +} + +struct TakeParameters { + FunctionContext* context; + std::shared_ptr values, indices; + TakeOptions options; + std::shared_ptr* out; +}; + +template +Status UnsafeAppend(Builder* builder, Scalar&& value) { + builder->UnsafeAppend(std::forward(value)); + return Status::OK(); +} + +Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +Status UnsafeAppend(StringBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +template +Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& indices, + OutBuilder* builder) { + for (int64_t i = 0; i != indices.length(); ++i) { + if (!AllIndicesValid && indices.IsNull(i)) { + builder->UnsafeAppendNull(); + continue; + } + auto index = indices.raw_values()[i]; + if (OutOfBounds == TakeOptions::ERROR && + static_cast(index) >= values.length()) { + return Status::Invalid("take index out of bounds"); + } + if (OutOfBounds == TakeOptions::TONULL && + static_cast(index) >= values.length()) { + builder->UnsafeAppendNull(); + continue; + } + if (!AllValuesValid && values.IsNull(index)) { + builder->UnsafeAppendNull(); + continue; + } + RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(index))); + } + return Status::OK(); +} + +template +Status TakeImpl(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, OutBuilder* builder) { + if (indices.null_count() == 0) { + return TakeImpl(context, values, indices, builder); + } + return TakeImpl(context, values, indices, builder); +} + +template +Status TakeImpl(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, OutBuilder* builder) { + if (values.null_count() == 0) { + return TakeImpl(context, values, indices, builder); + } + return TakeImpl(context, values, indices, builder); +} + +template +Status TakeImpl(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, const TakeOptions& options, + OutBuilder* builder) { + switch (options.out_of_bounds) { + case TakeOptions::ERROR: + return TakeImpl(context, values, indices, builder); + case TakeOptions::TONULL: + return TakeImpl(context, values, indices, builder); + case TakeOptions::UNSAFE: + return TakeImpl(context, values, indices, builder); + } +} + +template +struct UnpackValues { + using IndexArrayRef = const typename TypeTraits::ArrayType&; + + template + Status Visit(const ValueType&) { + using ValueArrayRef = const typename TypeTraits::ArrayType&; + using OutBuilder = typename TypeTraits::BuilderType; + IndexArrayRef indices = static_cast(*params_.indices); + ValueArrayRef values = static_cast(*params_.values); + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder)); + RETURN_NOT_OK(builder->Reserve(indices.length())); + RETURN_NOT_OK(TakeImpl(params_.context, values, indices, params_.options, + static_cast(builder.get()))); + return builder->Finish(params_.out); + } + + Status Visit(const NullType& t) { + auto indices_length = params_.indices->length(); + if (params_.options.out_of_bounds == TakeOptions::ERROR && indices_length != 0) { + auto indices = static_cast(*params_.indices).raw_values(); + auto max = *std::max_element(indices, indices + indices_length); + if (static_cast(max) > params_.values->length()) { + return Status::Invalid("out of bounds index"); + } + } + params_.out->reset(new NullArray(indices_length)); + return Status::OK(); + } + + Status Visit(const DictionaryType& t) { + UnpackValues unpack = {params_}; + RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack)); + (*params_.out)->data()->type = dictionary(t.index_type(), t.dictionary()); + return Status::OK(); + } + + Status Visit(const ExtensionType& t) { + // XXX can we just take from its storage? + return Status::NotImplemented("gathering values of type ", t); + } + + Status Visit(const UnionType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + + Status Visit(const ListType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + + Status Visit(const StructType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + + const TakeParameters& params_; +}; + +struct UnpackIndices { + template + enable_if_integer Visit(const IndexType&) { + UnpackValues unpack = {params_}; + return VisitTypeInline(*params_.values->type(), &unpack); + } + Status Visit(const DataType& other) { + return Status::Invalid("index type not supported: ", other); + } + const TakeParameters& params_; +}; + +Status TakeKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& indices, + Datum* out) { + if (!values.is_array() || !indices.is_array()) { + return Status::Invalid("TakeKernel expects array values and indices"); + } + std::shared_ptr out_array; + TakeParameters params; + params.context = ctx; + params.values = MakeArray(values.array()); + params.indices = MakeArray(indices.array()); + params.options = options_; + params.out = &out_array; + UnpackIndices unpack = {params}; + RETURN_NOT_OK(VisitTypeInline(*indices.type(), &unpack)); + *out = Datum(out_array); + return Status::OK(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h new file mode 100644 index 00000000000..336cfee6231 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class FunctionContext; + +struct ARROW_EXPORT TakeOptions { + enum { + // indices out of bounds will raise an error + ERROR, + // indices out of bounds will result in a null value + TONULL, + // indices out of bounds is undefined behavior + UNSAFE + } out_of_bounds = ERROR; +}; + +/// \brief Take from one array type to another +/// \param[in] context the FunctionContext +/// \param[in] values array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting array +ARROW_EXPORT +Status Take(FunctionContext* context, const Array& values, const Array& indices, + const TakeOptions& options, std::shared_ptr* out); + +/// \brief BinaryKernel implementing Take operation +class ARROW_EXPORT TakeKernel : public BinaryKernel { + public: + explicit TakeKernel(const std::shared_ptr& type, TakeOptions options = {}) + : type_(type), options_(options) {} + + Status Call(FunctionContext* ctx, const Datum& values, const Datum& indices, + Datum* out) override; + + std::shared_ptr out_type() const override { return type_; } + + private: + std::shared_ptr type_; + TakeOptions options_; +}; +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 88f3d129c29..e9aa5b80ed5 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -100,6 +100,11 @@ namespace arrow { typedef ::testing::Types NumericArrowTypes; +typedef ::testing::Types + ParameterFreeArrowTypes; class ChunkedArray; class Column; diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 0cea58483f4..263916fd639 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -304,9 +304,9 @@ using enable_if_primitive_ctype = template using enable_if_date = typename std::enable_if::value>::type; -template +template using enable_if_integer = - typename std::enable_if::value>::type; + typename std::enable_if::value, U>::type; template using enable_if_signed_integer = From 5fd6a2126d4758390c4da7b26f22ea66ac4ac8e9 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 12 Mar 2019 21:34:42 -0400 Subject: [PATCH 02/14] add default case for switch(out_of_bounds) --- cpp/src/arrow/compute/kernels/take.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 79ca60a6234..b2efe9565ec 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -118,6 +118,9 @@ Status TakeImpl(FunctionContext* context, const ValueArray& values, return TakeImpl(context, values, indices, builder); case TakeOptions::UNSAFE: return TakeImpl(context, values, indices, builder); + default: + ARROW_LOG(FATAL) << "how did we get here?"; + return Status::OK(); } } From e94fe312bfdb538617722c4f51250908a8f20a7c Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 12 Mar 2019 21:58:01 -0400 Subject: [PATCH 03/14] use Datum::make_array --- cpp/src/arrow/compute/kernels/take.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index b2efe9565ec..a3d79a3b2c0 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -21,6 +21,7 @@ #include "arrow/builder.h" #include "arrow/compute/context.h" #include "arrow/compute/kernels/take.h" +#include "arrow/util/logging.h" #include "arrow/visitor_inline.h" namespace arrow { @@ -31,7 +32,7 @@ Status Take(FunctionContext* context, const Array& values, const Array& indices, TakeKernel kernel(values.type(), options); Datum out_datum; RETURN_NOT_OK(kernel.Call(context, values.data(), indices.data(), &out_datum)); - *out = MakeArray(out_datum.array()); + *out = out_datum.make_array(); return Status::OK(); } @@ -119,8 +120,8 @@ Status TakeImpl(FunctionContext* context, const ValueArray& values, case TakeOptions::UNSAFE: return TakeImpl(context, values, indices, builder); default: - ARROW_LOG(FATAL) << "how did we get here?"; - return Status::OK(); + ARROW_LOG(FATAL) << "how did we get here"; + return Status::NotImplemented("how did we get here"); } } @@ -202,8 +203,8 @@ Status TakeKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& std::shared_ptr out_array; TakeParameters params; params.context = ctx; - params.values = MakeArray(values.array()); - params.indices = MakeArray(indices.array()); + params.values = values.make_array(); + params.indices = indices.make_array(); params.options = options_; params.out = &out_array; UnpackIndices unpack = {params}; From 14e837ef868afbc42c4828dea9cceaa99dbd0ed4 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 13 Mar 2019 14:44:33 -0400 Subject: [PATCH 04/14] add better explanatory comment for Take --- cpp/src/arrow/compute/kernels/take.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h index 336cfee6231..5d398181c6e 100644 --- a/cpp/src/arrow/compute/kernels/take.h +++ b/cpp/src/arrow/compute/kernels/take.h @@ -43,6 +43,14 @@ struct ARROW_EXPORT TakeOptions { }; /// \brief Take from one array type to another +/// +/// The output array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. +/// +/// For example given values ["a", "b", "c"] and indices +/// [2, 1], the output will be ["c", "b"]. +/// /// \param[in] context the FunctionContext /// \param[in] values array from which to take /// \param[in] indices which values to take From a7dd739757d1f0593ef80c30634a459e7dca95fe Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 14 Mar 2019 12:11:48 -0400 Subject: [PATCH 05/14] explain null behavior in Take doccomment --- cpp/src/arrow/compute/kernels/take.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h index 5d398181c6e..1d42656f13a 100644 --- a/cpp/src/arrow/compute/kernels/take.h +++ b/cpp/src/arrow/compute/kernels/take.h @@ -42,14 +42,16 @@ struct ARROW_EXPORT TakeOptions { } out_of_bounds = ERROR; }; -/// \brief Take from one array type to another +/// \brief Take from an array of values at indices in another array /// /// The output array will be of the same type as the input values /// array, with elements taken from the values array at the given -/// indices. +/// indices. If an index is null then the taken element will null. /// -/// For example given values ["a", "b", "c"] and indices -/// [2, 1], the output will be ["c", "b"]. +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] /// /// \param[in] context the FunctionContext /// \param[in] values array from which to take From c5fd669552f321ce4cd5484cf58dabb909a6298f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 15 Mar 2019 14:03:41 -0400 Subject: [PATCH 06/14] add test for taking from DictionaryArrays --- cpp/src/arrow/compute/kernels/take-test.cc | 33 ++++++++++++++++++++++ cpp/src/arrow/compute/kernels/take.cc | 6 +++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index b2094b994b8..b090cd3b634 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -34,6 +34,13 @@ using util::string_view; template class TestTakeKernel : public ComputeFixture, public TestBase { protected: + void AssertTake(const std::shared_ptr& values, + const std::shared_ptr& indices, TakeOptions options, + const std::shared_ptr& expected) { + std::shared_ptr actual; + ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, actual)); + AssertArraysEqual(*expected, *actual); + } void AssertTake(const std::shared_ptr& type, const std::string& values, const std::string& indices, TakeOptions options, const std::string& expected) { @@ -123,6 +130,20 @@ class TestTakeKernelWithString : public TestTakeKernel { TakeOptions options, const std::string& expected) { TestTakeKernel::AssertTake(utf8(), values, indices, options, expected); } + void AssertTakeDictionary(const std::string& dictionary_values, + const std::string& dictionary_indices, + const std::string& indices, TakeOptions options, + const std::string& expected_indices) { + auto type = dictionary(int8(), ArrayFromJSON(utf8(), dictionary_values)); + std::shared_ptr values, actual, expected; + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_indices), + &values)); + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), + &expected)); + ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *ArrayFromJSON(int8(), indices), + options, &actual)); + AssertArraysEqual(*expected, *actual); + } }; TEST_F(TestTakeKernelWithString, TakeString) { @@ -139,5 +160,17 @@ TEST_F(TestTakeKernelWithString, TakeString) { this->AssertTake(R"(["a", "b", "c"])", "[0, 9, 0]", options, R"(["a", null, "a"])"); } +TEST_F(TestTakeKernelWithString, TakeDictionary) { + TakeOptions options; + auto dict = R"(["a", "b", "c", "d", "e"])"; + this->AssertTakeDictionary(dict, "[0, 1, 4]", "[0, 1, 0]", options, "[0, 1, 0]"); + this->AssertTakeDictionary(dict, "[null, 1, 4]", "[0, 1, 0]", options, + "[null, 1, null]"); + this->AssertTakeDictionary(dict, "[0, 1, 4]", "[null, 1, 0]", options, "[null, 1, 0]"); + + options.out_of_bounds = TakeOptions::TONULL; + this->AssertTakeDictionary(dict, "[0, 1, 4]", "[0, 9, 0]", options, "[0, null, 0]"); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index a3d79a3b2c0..289ec63e4d9 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -157,7 +157,11 @@ struct UnpackValues { } Status Visit(const DictionaryType& t) { - UnpackValues unpack = {params_}; + auto dictionary_indices = params_.values->data()->Copy(); + dictionary_indices->type = t.index_type(); + TakeParameters params = params_; + params.values = MakeArray(dictionary_indices); + UnpackValues unpack = {params}; RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack)); (*params_.out)->data()->type = dictionary(t.index_type(), t.dictionary()); return Status::OK(); From 4a6932fd4ed625d03399b790fdc7d6bcc329627b Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 18 Mar 2019 15:05:18 -0400 Subject: [PATCH 07/14] add take kernel to api.h --- cpp/src/arrow/compute/api.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/compute/api.h b/cpp/src/arrow/compute/api.h index b6e609a71e1..eb0e7897e28 100644 --- a/cpp/src/arrow/compute/api.h +++ b/cpp/src/arrow/compute/api.h @@ -27,5 +27,6 @@ #include "arrow/compute/kernels/hash.h" // IWYU pragma: export #include "arrow/compute/kernels/mean.h" // IWYU pragma: export #include "arrow/compute/kernels/sum.h" // IWYU pragma: export +#include "arrow/compute/kernels/take.h" // IWYU pragma: export #endif // ARROW_COMPUTE_API_H From 8be7df1c7ee7fb6201753877489fb3c080f2cf69 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 19 Mar 2019 15:51:38 -0400 Subject: [PATCH 08/14] fix take-test --- cpp/src/arrow/compute/kernels/take-test.cc | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index b090cd3b634..f88e79383f0 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -34,11 +34,11 @@ using util::string_view; template class TestTakeKernel : public ComputeFixture, public TestBase { protected: - void AssertTake(const std::shared_ptr& values, - const std::shared_ptr& indices, TakeOptions options, - const std::shared_ptr& expected) { + void AssertTakeArrays(const std::shared_ptr& values, + const std::shared_ptr& indices, TakeOptions options, + const std::shared_ptr& expected) { std::shared_ptr actual; - ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, actual)); + ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, &actual)); AssertArraysEqual(*expected, *actual); } void AssertTake(const std::shared_ptr& type, const std::string& values, @@ -140,9 +140,8 @@ class TestTakeKernelWithString : public TestTakeKernel { &values)); ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), &expected)); - ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *ArrayFromJSON(int8(), indices), - options, &actual)); - AssertArraysEqual(*expected, *actual); + auto take_indices = ArrayFromJSON(int8(), indices); + this->AssertTakeArrays(values, take_indices, options, expected); } }; From 198320d7aa83f8e23ae3599094aaaa94b438241e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 19 Mar 2019 17:04:29 -0400 Subject: [PATCH 09/14] incorporate Francois' suggestions --- cpp/src/arrow/array/builder_binary.h | 19 +++--- cpp/src/arrow/compute/kernels/take-test.cc | 8 +-- cpp/src/arrow/compute/kernels/take.cc | 75 +++++++++++++--------- cpp/src/arrow/compute/kernels/take.h | 23 +++++-- 4 files changed, 78 insertions(+), 47 deletions(-) diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index a75512fb653..be4579cf8f4 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -195,9 +195,7 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { } void UnsafeAppend(util::string_view value) { -#ifndef NDEBUG - CheckValueSize(static_cast(value.size())); -#endif + CheckValueSize(value); UnsafeAppend(reinterpret_cast(value.data())); } @@ -206,16 +204,12 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { } Status Append(const util::string_view& view) { -#ifndef NDEBUG - CheckValueSize(static_cast(view.size())); -#endif + CheckValueSize(view); return Append(reinterpret_cast(view.data())); } Status Append(const std::string& s) { -#ifndef NDEBUG - CheckValueSize(static_cast(s.size())); -#endif + CheckValueSize(s); return Append(reinterpret_cast(s.data())); } @@ -258,6 +252,13 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { int32_t byte_width_; BufferBuilder byte_builder_; + template + void CheckValueSize(const Sized& s, decltype(s.size())* = nullptr) { +#ifndef NDEBUG + CheckValueSize(static_cast(s.size())); +#endif + } + #ifndef NDEBUG void CheckValueSize(int64_t size); #endif diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index f88e79383f0..6fd5ec446b7 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -92,7 +92,7 @@ TEST_F(TestTakeKernelWithBoolean, TakeBoolean) { ASSERT_RAISES(Invalid, this->Take(boolean(), "[true, false, true]", "[0, 9, 0]", options, &arr)); - options.out_of_bounds = TakeOptions::TONULL; + options.out_of_bounds = TakeOptions::TO_NULL; this->AssertTake("[true, false, true]", "[0, 9, 0]", options, "[true, null, true]"); } @@ -120,7 +120,7 @@ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", "[0, 9, 0]", options, &arr)); - options.out_of_bounds = TakeOptions::TONULL; + options.out_of_bounds = TakeOptions::TO_NULL; this->AssertTake("[7, 8, 9]", "[0, 9, 0]", options, "[7, null, 7]"); } @@ -155,7 +155,7 @@ TEST_F(TestTakeKernelWithString, TakeString) { ASSERT_RAISES(Invalid, this->Take(utf8(), R"(["a", "b", "c"])", "[0, 9, 0]", options, &arr)); - options.out_of_bounds = TakeOptions::TONULL; + options.out_of_bounds = TakeOptions::TO_NULL; this->AssertTake(R"(["a", "b", "c"])", "[0, 9, 0]", options, R"(["a", null, "a"])"); } @@ -167,7 +167,7 @@ TEST_F(TestTakeKernelWithString, TakeDictionary) { "[null, 1, null]"); this->AssertTakeDictionary(dict, "[0, 1, 4]", "[null, 1, 0]", options, "[null, 1, 0]"); - options.out_of_bounds = TakeOptions::TONULL; + options.out_of_bounds = TakeOptions::TO_NULL; this->AssertTakeDictionary(dict, "[0, 1, 4]", "[0, 9, 0]", options, "[0, null, 0]"); } diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 289ec63e4d9..69ea4789550 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -29,13 +29,20 @@ namespace compute { Status Take(FunctionContext* context, const Array& values, const Array& indices, const TakeOptions& options, std::shared_ptr* out) { - TakeKernel kernel(values.type(), options); Datum out_datum; - RETURN_NOT_OK(kernel.Call(context, values.data(), indices.data(), &out_datum)); + RETURN_NOT_OK( + Take(context, Datum(values.data()), Datum(indices.data()), options, &out_datum)); *out = out_datum.make_array(); return Status::OK(); } +Status Take(FunctionContext* context, const Datum& values, const Datum& indices, + const TakeOptions& options, Datum* out) { + TakeKernel kernel(values.type(), options); + RETURN_NOT_OK(kernel.Call(context, values, indices, out)); + return Status::OK(); +} + struct TakeParameters { FunctionContext* context; std::shared_ptr values, indices; @@ -61,22 +68,21 @@ Status UnsafeAppend(StringBuilder* builder, util::string_view value) { return Status::OK(); } -template +template Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& indices, OutBuilder* builder) { - for (int64_t i = 0; i != indices.length(); ++i) { + auto raw_indices = indices.raw_values(); + for (int64_t i = 0; i < indices.length(); ++i) { if (!AllIndicesValid && indices.IsNull(i)) { builder->UnsafeAppendNull(); continue; } - auto index = indices.raw_values()[i]; - if (OutOfBounds == TakeOptions::ERROR && - static_cast(index) >= values.length()) { + auto index = static_cast(raw_indices[i]); + if (B == TakeOptions::ERROR && (index < 0 || index >= values.length())) { return Status::Invalid("take index out of bounds"); } - if (OutOfBounds == TakeOptions::TONULL && - static_cast(index) >= values.length()) { + if (B == TakeOptions::TO_NULL && (index < 0 || index >= values.length())) { builder->UnsafeAppendNull(); continue; } @@ -89,23 +95,24 @@ Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& in return Status::OK(); } -template +template Status TakeImpl(FunctionContext* context, const ValueArray& values, const IndexArray& indices, OutBuilder* builder) { if (indices.null_count() == 0) { - return TakeImpl(context, values, indices, builder); + return TakeImpl(context, values, indices, builder); } - return TakeImpl(context, values, indices, builder); + return TakeImpl(context, values, indices, builder); } -template +template Status TakeImpl(FunctionContext* context, const ValueArray& values, const IndexArray& indices, OutBuilder* builder) { if (values.null_count() == 0) { - return TakeImpl(context, values, indices, builder); + return TakeImpl(context, values, indices, builder); } - return TakeImpl(context, values, indices, builder); + return TakeImpl(context, values, indices, builder); } template @@ -115,8 +122,8 @@ Status TakeImpl(FunctionContext* context, const ValueArray& values, switch (options.out_of_bounds) { case TakeOptions::ERROR: return TakeImpl(context, values, indices, builder); - case TakeOptions::TONULL: - return TakeImpl(context, values, indices, builder); + case TakeOptions::TO_NULL: + return TakeImpl(context, values, indices, builder); case TakeOptions::UNSAFE: return TakeImpl(context, values, indices, builder); default: @@ -147,8 +154,10 @@ struct UnpackValues { auto indices_length = params_.indices->length(); if (params_.options.out_of_bounds == TakeOptions::ERROR && indices_length != 0) { auto indices = static_cast(*params_.indices).raw_values(); - auto max = *std::max_element(indices, indices + indices_length); - if (static_cast(max) > params_.values->length()) { + auto minmax = std::minmax_element(indices, indices + indices_length); + auto min = static_cast(*minmax.first); + auto max = static_cast(*minmax.second); + if (min < 0 || max >= params_.values->length()) { return Status::Invalid("out of bounds index"); } } @@ -157,14 +166,20 @@ struct UnpackValues { } Status Visit(const DictionaryType& t) { - auto dictionary_indices = params_.values->data()->Copy(); - dictionary_indices->type = t.index_type(); - TakeParameters params = params_; - params.values = MakeArray(dictionary_indices); - UnpackValues unpack = {params}; - RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack)); - (*params_.out)->data()->type = dictionary(t.index_type(), t.dictionary()); - return Status::OK(); + std::shared_ptr taken_indices; + { + // To take from a dictionary, apply the current kernel to the dictionary's + // indices. (Use UnpackValues since IndexType is already unpacked) + auto indices = static_cast(params_.values.get())->indices(); + TakeParameters params = params_; + params.values = indices; + params.out = &taken_indices; + UnpackValues unpack = {params}; + RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack)); + } + // create output dictionary from taken indices + return DictionaryArray::FromArrays(dictionary(t.index_type(), t.dictionary()), + taken_indices, params_.out); } Status Visit(const ExtensionType& t) { @@ -193,9 +208,11 @@ struct UnpackIndices { UnpackValues unpack = {params_}; return VisitTypeInline(*params_.values->type(), &unpack); } + Status Visit(const DataType& other) { return Status::Invalid("index type not supported: ", other); } + const TakeParameters& params_; }; diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h index 1d42656f13a..ad2ec7004e3 100644 --- a/cpp/src/arrow/compute/kernels/take.h +++ b/cpp/src/arrow/compute/kernels/take.h @@ -32,12 +32,14 @@ namespace compute { class FunctionContext; struct ARROW_EXPORT TakeOptions { - enum { - // indices out of bounds will raise an error + enum OutOfBoundsBehavior { + // Out of bounds indices will raise an error ERROR, - // indices out of bounds will result in a null value - TONULL, - // indices out of bounds is undefined behavior + // Out of bounds indices will result in a null value + TO_NULL, + // Bounds checking will be skipped, which is faster. + // Only use this if indices are known to be within bounds; + // out of bounds indices will result in undefined behavior UNSAFE } out_of_bounds = ERROR; }; @@ -62,6 +64,17 @@ ARROW_EXPORT Status Take(FunctionContext* context, const Array& values, const Array& indices, const TakeOptions& options, std::shared_ptr* out); +/// \brief Take from an array of values at indices in another array +/// +/// \param[in] context the FunctionContext +/// \param[in] values datum from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting datum +ARROW_EXPORT +Status Take(FunctionContext* context, const Datum& values, const Datum& indices, + const TakeOptions& options, Datum* out); + /// \brief BinaryKernel implementing Take operation class ARROW_EXPORT TakeKernel : public BinaryKernel { public: From b64722d13e54d67e8c8105bd6ae77e45a191c80d Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 19 Mar 2019 21:19:48 -0400 Subject: [PATCH 10/14] use NULLPTR in public headers --- cpp/src/arrow/array/builder_binary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index be4579cf8f4..954f0bb7d7a 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -253,7 +253,7 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { BufferBuilder byte_builder_; template - void CheckValueSize(const Sized& s, decltype(s.size())* = nullptr) { + void CheckValueSize(const Sized& s, decltype(s.size())* = NULLPTR) { #ifndef NDEBUG CheckValueSize(static_cast(s.size())); #endif From cc821c61775137e5b23b4e351468ad7268562d7e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 20 Mar 2019 13:54:24 -0400 Subject: [PATCH 11/14] avoid conflict with macro in R_ext/RS.h --- cpp/src/arrow/compute/kernels/take.cc | 8 ++++---- cpp/src/arrow/compute/kernels/take.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 69ea4789550..6cff7d2c030 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -79,7 +79,7 @@ Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& in continue; } auto index = static_cast(raw_indices[i]); - if (B == TakeOptions::ERROR && (index < 0 || index >= values.length())) { + if (B == TakeOptions::RAISE && (index < 0 || index >= values.length())) { return Status::Invalid("take index out of bounds"); } if (B == TakeOptions::TO_NULL && (index < 0 || index >= values.length())) { @@ -120,8 +120,8 @@ Status TakeImpl(FunctionContext* context, const ValueArray& values, const IndexArray& indices, const TakeOptions& options, OutBuilder* builder) { switch (options.out_of_bounds) { - case TakeOptions::ERROR: - return TakeImpl(context, values, indices, builder); + case TakeOptions::RAISE: + return TakeImpl(context, values, indices, builder); case TakeOptions::TO_NULL: return TakeImpl(context, values, indices, builder); case TakeOptions::UNSAFE: @@ -152,7 +152,7 @@ struct UnpackValues { Status Visit(const NullType& t) { auto indices_length = params_.indices->length(); - if (params_.options.out_of_bounds == TakeOptions::ERROR && indices_length != 0) { + if (params_.options.out_of_bounds == TakeOptions::RAISE && indices_length != 0) { auto indices = static_cast(*params_.indices).raw_values(); auto minmax = std::minmax_element(indices, indices + indices_length); auto min = static_cast(*minmax.first); diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h index ad2ec7004e3..e26bdedd3e5 100644 --- a/cpp/src/arrow/compute/kernels/take.h +++ b/cpp/src/arrow/compute/kernels/take.h @@ -34,14 +34,14 @@ class FunctionContext; struct ARROW_EXPORT TakeOptions { enum OutOfBoundsBehavior { // Out of bounds indices will raise an error - ERROR, + RAISE, // Out of bounds indices will result in a null value TO_NULL, // Bounds checking will be skipped, which is faster. // Only use this if indices are known to be within bounds; // out of bounds indices will result in undefined behavior UNSAFE - } out_of_bounds = ERROR; + } out_of_bounds = RAISE; }; /// \brief Take from an array of values at indices in another array From 3a1ef1285e591d71f6aca125c0662e6cc6c4eab0 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 20 Mar 2019 14:09:24 -0400 Subject: [PATCH 12/14] renaming, remove superflous DCHECK --- cpp/src/arrow/compute/kernels/take.cc | 32 ++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 6cff7d2c030..23273c816b1 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -97,8 +97,8 @@ Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& in template -Status TakeImpl(FunctionContext* context, const ValueArray& values, - const IndexArray& indices, OutBuilder* builder) { +Status UnpackIndicesNullCount(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, OutBuilder* builder) { if (indices.null_count() == 0) { return TakeImpl(context, values, indices, builder); } @@ -107,27 +107,28 @@ Status TakeImpl(FunctionContext* context, const ValueArray& values, template -Status TakeImpl(FunctionContext* context, const ValueArray& values, - const IndexArray& indices, OutBuilder* builder) { +Status UnpackValuesNullCount(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, OutBuilder* builder) { if (values.null_count() == 0) { - return TakeImpl(context, values, indices, builder); + return UnpackIndicesNullCount(context, values, indices, builder); } - return TakeImpl(context, values, indices, builder); + return UnpackIndicesNullCount(context, values, indices, builder); } template -Status TakeImpl(FunctionContext* context, const ValueArray& values, - const IndexArray& indices, const TakeOptions& options, - OutBuilder* builder) { +Status UnpackOutOfBoundsBehavior(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, const TakeOptions& options, + OutBuilder* builder) { switch (options.out_of_bounds) { case TakeOptions::RAISE: - return TakeImpl(context, values, indices, builder); + return UnpackValuesNullCount(context, values, indices, builder); case TakeOptions::TO_NULL: - return TakeImpl(context, values, indices, builder); + return UnpackValuesNullCount(context, values, indices, + builder); case TakeOptions::UNSAFE: - return TakeImpl(context, values, indices, builder); + return UnpackValuesNullCount(context, values, indices, + builder); default: - ARROW_LOG(FATAL) << "how did we get here"; return Status::NotImplemented("how did we get here"); } } @@ -145,8 +146,9 @@ struct UnpackValues { std::unique_ptr builder; RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder)); RETURN_NOT_OK(builder->Reserve(indices.length())); - RETURN_NOT_OK(TakeImpl(params_.context, values, indices, params_.options, - static_cast(builder.get()))); + RETURN_NOT_OK(UnpackOutOfBoundsBehavior(params_.context, values, indices, + params_.options, + static_cast(builder.get()))); return builder->Finish(params_.out); } From 99792b7957970ca4fe949f391749d2052014c061 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 8 Apr 2019 14:10:31 -0400 Subject: [PATCH 13/14] address review comments --- cpp/src/arrow/array/builder_binary.h | 48 +++++++++++----------- cpp/src/arrow/compute/kernels/take-test.cc | 39 ++++++++++-------- cpp/src/arrow/testing/gtest_util.h | 5 --- 3 files changed, 44 insertions(+), 48 deletions(-) diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index 954f0bb7d7a..954f58e7cec 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -185,18 +185,8 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { Status Append(const uint8_t* value) { ARROW_RETURN_NOT_OK(Reserve(1)); - UnsafeAppendToBitmap(true); - return byte_builder_.Append(value, byte_width_); - } - - void UnsafeAppend(const uint8_t* value) { - UnsafeAppendToBitmap(true); - byte_builder_.UnsafeAppend(value, byte_width_); - } - - void UnsafeAppend(util::string_view value) { - CheckValueSize(value); - UnsafeAppend(reinterpret_cast(value.data())); + UnsafeAppend(value); + return Status::OK(); } Status Append(const char* value) { @@ -204,26 +194,41 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { } Status Append(const util::string_view& view) { - CheckValueSize(view); - return Append(reinterpret_cast(view.data())); + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(view); + return Status::OK(); } Status Append(const std::string& s) { - CheckValueSize(s); - return Append(reinterpret_cast(s.data())); + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(s); + return Status::OK(); } template Status Append(const std::array& value) { ARROW_RETURN_NOT_OK(Reserve(1)); - UnsafeAppendToBitmap(true); - return byte_builder_.Append(value); + UnsafeAppend( + util::string_view(reinterpret_cast(value.data()), value.size())); + return Status::OK(); } Status AppendValues(const uint8_t* data, int64_t length, const uint8_t* valid_bytes = NULLPTR); Status AppendNull(); + void UnsafeAppend(const uint8_t* value) { + UnsafeAppendToBitmap(true); + byte_builder_.UnsafeAppend(value, byte_width_); + } + + void UnsafeAppend(util::string_view value) { +#ifndef NDEBUG + CheckValueSize(static_cast(value.size())); +#endif + UnsafeAppend(reinterpret_cast(value.data())); + } + void UnsafeAppendNull() { UnsafeAppendToBitmap(false); byte_builder_.UnsafeAdvance(byte_width_); @@ -252,13 +257,6 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { int32_t byte_width_; BufferBuilder byte_builder_; - template - void CheckValueSize(const Sized& s, decltype(s.size())* = NULLPTR) { -#ifndef NDEBUG - CheckValueSize(static_cast(s.size())); -#endif - } - #ifndef NDEBUG void CheckValueSize(int64_t size); #endif diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index 6fd5ec446b7..65ec0c12f12 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -45,14 +45,17 @@ class TestTakeKernel : public ComputeFixture, public TestBase { const std::string& indices, TakeOptions options, const std::string& expected) { std::shared_ptr actual; - ASSERT_OK(this->Take(type, values, indices, options, &actual)); - AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); + + for (auto index_type : {int8(), uint32()}) { + ASSERT_OK(this->Take(type, values, index_type, indices, options, &actual)); + AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); + } } Status Take(const std::shared_ptr& type, const std::string& values, - const std::string& indices, TakeOptions options, - std::shared_ptr* out) { + const std::shared_ptr& index_type, const std::string& indices, + TakeOptions options, std::shared_ptr* out) { return arrow::compute::Take(&this->ctx_, *ArrayFromJSON(type, values), - *ArrayFromJSON(int8(), indices), options, out); + *ArrayFromJSON(index_type, indices), options, out); } }; @@ -69,8 +72,8 @@ TEST_F(TestTakeKernelWithNull, TakeNull) { this->AssertTake("[null, null, null]", "[0, 1, 0]", options, "[null, null, null]"); std::shared_ptr arr; - ASSERT_RAISES(Invalid, - this->Take(null(), "[null, null, null]", "[0, 9, 0]", options, &arr)); + ASSERT_RAISES(Invalid, this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]", + options, &arr)); } class TestTakeKernelWithBoolean : public TestTakeKernel { @@ -89,8 +92,8 @@ TEST_F(TestTakeKernelWithBoolean, TakeBoolean) { this->AssertTake("[true, false, true]", "[null, 1, 0]", options, "[null, false, true]"); std::shared_ptr arr; - ASSERT_RAISES(Invalid, - this->Take(boolean(), "[true, false, true]", "[0, 9, 0]", options, &arr)); + ASSERT_RAISES(Invalid, this->Take(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", + options, &arr)); options.out_of_bounds = TakeOptions::TO_NULL; this->AssertTake("[true, false, true]", "[0, 9, 0]", options, "[true, null, true]"); @@ -117,8 +120,8 @@ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]"); std::shared_ptr arr; - ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", "[0, 9, 0]", - options, &arr)); + ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", int8(), + "[0, 9, 0]", options, &arr)); options.out_of_bounds = TakeOptions::TO_NULL; this->AssertTake("[7, 8, 9]", "[0, 9, 0]", options, "[7, null, 7]"); @@ -152,8 +155,8 @@ TEST_F(TestTakeKernelWithString, TakeString) { this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", options, R"([null, "b", "a"])"); std::shared_ptr arr; - ASSERT_RAISES(Invalid, - this->Take(utf8(), R"(["a", "b", "c"])", "[0, 9, 0]", options, &arr)); + ASSERT_RAISES(Invalid, this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]", + options, &arr)); options.out_of_bounds = TakeOptions::TO_NULL; this->AssertTake(R"(["a", "b", "c"])", "[0, 9, 0]", options, R"(["a", null, "a"])"); @@ -162,13 +165,13 @@ TEST_F(TestTakeKernelWithString, TakeString) { TEST_F(TestTakeKernelWithString, TakeDictionary) { TakeOptions options; auto dict = R"(["a", "b", "c", "d", "e"])"; - this->AssertTakeDictionary(dict, "[0, 1, 4]", "[0, 1, 0]", options, "[0, 1, 0]"); - this->AssertTakeDictionary(dict, "[null, 1, 4]", "[0, 1, 0]", options, - "[null, 1, null]"); - this->AssertTakeDictionary(dict, "[0, 1, 4]", "[null, 1, 0]", options, "[null, 1, 0]"); + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", options, "[3, 4, 3]"); + this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", options, + "[null, 4, null]"); + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", options, "[null, 4, 3]"); options.out_of_bounds = TakeOptions::TO_NULL; - this->AssertTakeDictionary(dict, "[0, 1, 4]", "[0, 9, 0]", options, "[0, null, 0]"); + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 9, 0]", options, "[3, null, 3]"); } } // namespace compute diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index e9aa5b80ed5..88f3d129c29 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -100,11 +100,6 @@ namespace arrow { typedef ::testing::Types NumericArrowTypes; -typedef ::testing::Types - ParameterFreeArrowTypes; class ChunkedArray; class Column; From a0250e5bad75c05b264c56ff0456fdd6c6c818dd Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 9 Apr 2019 10:00:58 -0400 Subject: [PATCH 14/14] Remove out of bounds option- always raise an error --- cpp/src/arrow/compute/kernels/take-test.cc | 12 ------ cpp/src/arrow/compute/kernels/take.cc | 50 ++++++---------------- cpp/src/arrow/compute/kernels/take.h | 13 +----- 3 files changed, 14 insertions(+), 61 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index 65ec0c12f12..110e07375a9 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -94,9 +94,6 @@ TEST_F(TestTakeKernelWithBoolean, TakeBoolean) { std::shared_ptr arr; ASSERT_RAISES(Invalid, this->Take(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", options, &arr)); - - options.out_of_bounds = TakeOptions::TO_NULL; - this->AssertTake("[true, false, true]", "[0, 9, 0]", options, "[true, null, true]"); } template @@ -122,9 +119,6 @@ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { std::shared_ptr arr; ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", int8(), "[0, 9, 0]", options, &arr)); - - options.out_of_bounds = TakeOptions::TO_NULL; - this->AssertTake("[7, 8, 9]", "[0, 9, 0]", options, "[7, null, 7]"); } class TestTakeKernelWithString : public TestTakeKernel { @@ -157,9 +151,6 @@ TEST_F(TestTakeKernelWithString, TakeString) { std::shared_ptr arr; ASSERT_RAISES(Invalid, this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]", options, &arr)); - - options.out_of_bounds = TakeOptions::TO_NULL; - this->AssertTake(R"(["a", "b", "c"])", "[0, 9, 0]", options, R"(["a", null, "a"])"); } TEST_F(TestTakeKernelWithString, TakeDictionary) { @@ -169,9 +160,6 @@ TEST_F(TestTakeKernelWithString, TakeDictionary) { this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", options, "[null, 4, null]"); this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", options, "[null, 4, 3]"); - - options.out_of_bounds = TakeOptions::TO_NULL; - this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 9, 0]", options, "[3, null, 3]"); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 23273c816b1..1dd34a92449 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -68,8 +68,8 @@ Status UnsafeAppend(StringBuilder* builder, util::string_view value) { return Status::OK(); } -template +template Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& indices, OutBuilder* builder) { auto raw_indices = indices.raw_values(); @@ -79,13 +79,9 @@ Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& in continue; } auto index = static_cast(raw_indices[i]); - if (B == TakeOptions::RAISE && (index < 0 || index >= values.length())) { + if (index < 0 || index >= values.length()) { return Status::Invalid("take index out of bounds"); } - if (B == TakeOptions::TO_NULL && (index < 0 || index >= values.length())) { - builder->UnsafeAppendNull(); - continue; - } if (!AllValuesValid && values.IsNull(index)) { builder->UnsafeAppendNull(); continue; @@ -95,42 +91,23 @@ Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& in return Status::OK(); } -template +template Status UnpackIndicesNullCount(FunctionContext* context, const ValueArray& values, const IndexArray& indices, OutBuilder* builder) { if (indices.null_count() == 0) { - return TakeImpl(context, values, indices, builder); + return TakeImpl(context, values, indices, builder); } - return TakeImpl(context, values, indices, builder); + return TakeImpl(context, values, indices, builder); } -template +template Status UnpackValuesNullCount(FunctionContext* context, const ValueArray& values, const IndexArray& indices, OutBuilder* builder) { if (values.null_count() == 0) { - return UnpackIndicesNullCount(context, values, indices, builder); - } - return UnpackIndicesNullCount(context, values, indices, builder); -} - -template -Status UnpackOutOfBoundsBehavior(FunctionContext* context, const ValueArray& values, - const IndexArray& indices, const TakeOptions& options, - OutBuilder* builder) { - switch (options.out_of_bounds) { - case TakeOptions::RAISE: - return UnpackValuesNullCount(context, values, indices, builder); - case TakeOptions::TO_NULL: - return UnpackValuesNullCount(context, values, indices, - builder); - case TakeOptions::UNSAFE: - return UnpackValuesNullCount(context, values, indices, - builder); - default: - return Status::NotImplemented("how did we get here"); + return UnpackIndicesNullCount(context, values, indices, builder); } + return UnpackIndicesNullCount(context, values, indices, builder); } template @@ -146,15 +123,14 @@ struct UnpackValues { std::unique_ptr builder; RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder)); RETURN_NOT_OK(builder->Reserve(indices.length())); - RETURN_NOT_OK(UnpackOutOfBoundsBehavior(params_.context, values, indices, - params_.options, - static_cast(builder.get()))); + RETURN_NOT_OK(UnpackValuesNullCount(params_.context, values, indices, + static_cast(builder.get()))); return builder->Finish(params_.out); } Status Visit(const NullType& t) { auto indices_length = params_.indices->length(); - if (params_.options.out_of_bounds == TakeOptions::RAISE && indices_length != 0) { + if (indices_length != 0) { auto indices = static_cast(*params_.indices).raw_values(); auto minmax = std::minmax_element(indices, indices + indices_length); auto min = static_cast(*minmax.first); diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h index e26bdedd3e5..bfd69112786 100644 --- a/cpp/src/arrow/compute/kernels/take.h +++ b/cpp/src/arrow/compute/kernels/take.h @@ -31,18 +31,7 @@ namespace compute { class FunctionContext; -struct ARROW_EXPORT TakeOptions { - enum OutOfBoundsBehavior { - // Out of bounds indices will raise an error - RAISE, - // Out of bounds indices will result in a null value - TO_NULL, - // Bounds checking will be skipped, which is faster. - // Only use this if indices are known to be within bounds; - // out of bounds indices will result in undefined behavior - UNSAFE - } out_of_bounds = RAISE; -}; +struct ARROW_EXPORT TakeOptions {}; /// \brief Take from an array of values at indices in another array ///