diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc index 2a3a24f0648..1b2ba31613a 100644 --- a/cpp/src/arrow/compute/kernels/filter.cc +++ b/cpp/src/arrow/compute/kernels/filter.cc @@ -22,6 +22,7 @@ #include #include +#include "arrow/array/concatenate.h" #include "arrow/builder.h" #include "arrow/compute/kernels/take_internal.h" #include "arrow/record_batch.h" @@ -163,5 +164,91 @@ Status Filter(FunctionContext* ctx, const RecordBatch& batch, const Array& filte return Status::OK(); } +Status Filter(FunctionContext* ctx, const ChunkedArray& values, const Array& filter, + std::shared_ptr* out) { + if (values.length() != filter.length()) { + return Status::Invalid("filter and value array must have identical lengths"); + } + auto num_chunks = values.num_chunks(); + std::vector> new_chunks(num_chunks); + std::shared_ptr current_chunk; + int64_t offset = 0; + int64_t len; + + for (int i = 0; i < num_chunks; i++) { + current_chunk = values.chunk(i); + len = current_chunk->length(); + RETURN_NOT_OK( + Filter(ctx, *current_chunk, *filter.Slice(offset, len), &new_chunks[i])); + offset += len; + } + + *out = std::make_shared(std::move(new_chunks)); + return Status::OK(); +} + +Status Filter(FunctionContext* ctx, const ChunkedArray& values, + const ChunkedArray& filter, std::shared_ptr* out) { + if (values.length() != filter.length()) { + return Status::Invalid("filter and value array must have identical lengths"); + } + auto num_chunks = values.num_chunks(); + std::vector> new_chunks(num_chunks); + std::shared_ptr current_chunk; + std::shared_ptr current_chunked_filter; + std::shared_ptr current_filter; + int64_t offset = 0; + int64_t len; + + for (int i = 0; i < num_chunks; i++) { + current_chunk = values.chunk(i); + len = current_chunk->length(); + if (len > 0) { + current_chunked_filter = filter.Slice(offset, len); + if (current_chunked_filter->num_chunks() == 1) { + current_filter = current_chunked_filter->chunk(0); + } else { + 
// Concatenate the chunks of the filter so we have an Array + RETURN_NOT_OK(Concatenate(current_chunked_filter->chunks(), default_memory_pool(), + &current_filter)); + } + RETURN_NOT_OK(Filter(ctx, *current_chunk, *current_filter, &new_chunks[i])); + offset += len; + } else { + // Put a zero length array there, which we know our current chunk to be + new_chunks[i] = current_chunk; + } + } + + *out = std::make_shared<ChunkedArray>(std::move(new_chunks)); + return Status::OK(); +} + +Status Filter(FunctionContext* ctx, const Table& table, const Array& filter, + std::shared_ptr<Table>* out) { + auto ncols = table.num_columns(); + + std::vector<std::shared_ptr<ChunkedArray>> columns(ncols); + + for (int j = 0; j < ncols; j++) { + RETURN_NOT_OK(Filter(ctx, *table.column(j), filter, &columns[j])); + } + *out = Table::Make(table.schema(), columns); + return Status::OK(); +} + +Status Filter(FunctionContext* ctx, const Table& table, const ChunkedArray& filter, + std::shared_ptr<Table>
* out) { + auto ncols = table.num_columns(); + + std::vector<std::shared_ptr<ChunkedArray>> columns(ncols); + + for (int j = 0; j < ncols; j++) { + RETURN_NOT_OK(Filter(ctx, *table.column(j), filter, &columns[j])); + } + *out = Table::Make(table.schema(), columns); + return Status::OK(); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter.h b/cpp/src/arrow/compute/kernels/filter.h index f78c2e46df7..bc7f75db539 100644 --- a/cpp/src/arrow/compute/kernels/filter.h +++ b/cpp/src/arrow/compute/kernels/filter.h @@ -50,6 +50,44 @@ ARROW_EXPORT Status Filter(FunctionContext* ctx, const Array& values, const Array& filter, std::shared_ptr<Array>* out); +/// \brief Filter a chunked array with a boolean selection filter +/// +/// The output chunked array will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will result in nulls +/// in the output. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// = ["b", "c", null, "f"] +/// +/// \param[in] ctx the FunctionContext +/// \param[in] values chunked array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Status Filter(FunctionContext* ctx, const ChunkedArray& values, const Array& filter, + std::shared_ptr<ChunkedArray>* out); + +/// \brief Filter a chunked array with a boolean selection filter +/// +/// The output chunked array will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will result in nulls +/// in the output. 
+/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// = ["b", "c", null, "f"] +/// +/// \param[in] ctx the FunctionContext +/// \param[in] values chunked array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Status Filter(FunctionContext* ctx, const ChunkedArray& values, + const ChunkedArray& filter, std::shared_ptr<ChunkedArray>* out); + /// \brief Filter a record batch with a boolean selection filter /// /// The output record batch's columns will be populated with values from corresponding @@ -60,10 +98,41 @@ Status Filter(FunctionContext* ctx, const Array& values, const Array& filter, /// \param[in] batch record batch to filter /// \param[in] filter indicates which values should be filtered out /// \param[out] out resulting record batch +/// NOTE: Experimental API ARROW_EXPORT Status Filter(FunctionContext* ctx, const RecordBatch& batch, const Array& filter, std::shared_ptr<RecordBatch>* out); +/// \brief Filter a table with a boolean selection filter +/// +/// The output table's columns will be populated with values from corresponding +/// columns of the input at positions where the selection filter is not 0. Nulls in the +/// filter will result in nulls in each column of the output. +/// +/// \param[in] ctx the FunctionContext +/// \param[in] table table to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting table +/// NOTE: Experimental API +ARROW_EXPORT +Status Filter(FunctionContext* ctx, const Table& table, const Array& filter, + std::shared_ptr<Table>
* out); + +/// \brief Filter a table with a boolean selection filter +/// +/// The output table's columns will be populated with values from corresponding +/// columns of the input at positions where the selection filter is not 0. Nulls in the +/// filter will result in nulls in the output. +/// +/// \param[in] ctx the FunctionContext +/// \param[in] table table to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting table +/// NOTE: Experimental API +ARROW_EXPORT +Status Filter(FunctionContext* ctx, const Table& table, const ChunkedArray& filter, + std::shared_ptr<Table>
* out); + /// \brief Filter an array with a boolean selection filter /// /// \param[in] ctx the FunctionContext diff --git a/cpp/src/arrow/compute/kernels/filter_test.cc b/cpp/src/arrow/compute/kernels/filter_test.cc index bb685f474dc..0c00ce4ca1e 100644 --- a/cpp/src/arrow/compute/kernels/filter_test.cc +++ b/cpp/src/arrow/compute/kernels/filter_test.cc @@ -468,5 +468,157 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) { } } +class TestFilterKernelWithRecordBatch : public TestFilterKernel { + public: + void AssertFilter(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& selection, const std::string& expected_batch) { + std::shared_ptr actual; + + ASSERT_OK(this->Filter(schm, batch_json, selection, &actual)); + ASSERT_OK(actual->Validate()); + ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); + } + + Status Filter(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& selection, std::shared_ptr* out) { + auto batch = RecordBatchFromJSON(schm, batch_json); + return arrow::compute::Filter(&this->ctx_, *batch, + *ArrayFromJSON(boolean(), selection), out); + } +}; + +TEST_F(TestFilterKernelWithRecordBatch, FilterRecordBatch) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + auto struct_json = R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertFilter(schm, struct_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(schm, struct_json, "[0, 1, 1, null]", R"([ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": null, "b": null} + ])"); + this->AssertFilter(schm, struct_json, "[1, 1, 1, 1]", struct_json); + this->AssertFilter(schm, struct_json, "[1, 0, 1, 0]", R"([ + {"a": null, "b": "yo"}, + {"a": 2, "b": "hello"} + ])"); +} + +class TestFilterKernelWithChunkedArray : public TestFilterKernel { + public: + void AssertFilter(const std::shared_ptr& type, + const 
std::vector& values, const std::string& filter, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->FilterWithArray(type, values, filter, &actual)); + ASSERT_OK(actual->Validate()); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + void AssertChunkedFilter(const std::shared_ptr& type, + const std::vector& values, + const std::vector& filter, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->FilterWithChunkedArray(type, values, filter, &actual)); + ASSERT_OK(actual->Validate()); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + Status FilterWithArray(const std::shared_ptr& type, + const std::vector& values, + const std::string& filter, std::shared_ptr* out) { + return arrow::compute::Filter(&this->ctx_, *ChunkedArrayFromJSON(type, values), + *ArrayFromJSON(boolean(), filter), out); + } + + Status FilterWithChunkedArray(const std::shared_ptr& type, + const std::vector& values, + const std::vector& filter, + std::shared_ptr* out) { + return arrow::compute::Filter(&this->ctx_, *ChunkedArrayFromJSON(type, values), + *ChunkedArrayFromJSON(boolean(), filter), out); + } +}; + +TEST_F(TestFilterKernelWithChunkedArray, FilterChunkedArray) { + this->AssertFilter(int8(), {"[]"}, "[]", {"[]"}); + this->AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {"[]"}); + + this->AssertFilter(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0]", {"[]", "[8]"}); + this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0]", "[1, 0]"}, {"[]", "[8]"}); + this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0, 1]", "[0]"}, {"[8]", "[]"}); + + std::shared_ptr arr; + ASSERT_RAISES( + Invalid, this->FilterWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 1, 1]", &arr)); + ASSERT_RAISES(Invalid, this->FilterWithChunkedArray(int8(), {"[7]", "[8, 9]"}, + {"[0, 1, 0]", "[1, 1]"}, &arr)); +} + +class TestFilterKernelWithTable : public TestFilterKernel
{ + public: + void AssertFilter(const std::shared_ptr& schm, + const std::vector& table_json, const std::string& filter, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->FilterWithArray(schm, table_json, filter, &actual)); + ASSERT_OK(actual->Validate()); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + void AssertChunkedFilter(const std::shared_ptr& schm, + const std::vector& table_json, + const std::vector& filter, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->FilterWithChunkedArray(schm, table_json, filter, &actual)); + ASSERT_OK(actual->Validate()); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + Status FilterWithArray(const std::shared_ptr& schm, + const std::vector& values, + const std::string& filter, std::shared_ptr
* out) { + return arrow::compute::Filter(&this->ctx_, *TableFromJSON(schm, values), + *ArrayFromJSON(boolean(), filter), out); + } + + Status FilterWithChunkedArray(const std::shared_ptr& schm, + const std::vector& values, + const std::vector& filter, + std::shared_ptr
* out) { + return arrow::compute::Filter(&this->ctx_, *TableFromJSON(schm, values), + *ChunkedArrayFromJSON(boolean(), filter), out); + } +}; + +TEST_F(TestFilterKernelWithTable, FilterTable) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + std::vector table_json = { + "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]", + "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"}; + this->AssertFilter(schm, table_json, "[0, 0, 0, 0]", {"[]", "[]"}); + this->AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, {"[]", "[]"}); + + std::vector expected2 = { + "[{\"a\": 1, \"b\": \"\"}]", + "[{\"a\": 2, \"b\": \"hello\"},{\"a\": null, \"b\": null}]"}; + this->AssertFilter(schm, table_json, "[0, 1, 1, null]", expected2); + this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, expected2); + this->AssertFilter(schm, table_json, "[1, 1, 1, 1]", table_json); + this->AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"}, table_json); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index a5210d79312..2fa860c7682 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -18,7 +18,9 @@ #include #include #include +#include +#include "arrow/array/concatenate.h" #include "arrow/compute/kernels/take.h" #include "arrow/compute/kernels/take_internal.h" #include "arrow/util/logging.h" @@ -99,5 +101,99 @@ Status Take(FunctionContext* ctx, const Datum& values, const Datum& indices, return kernel->Call(ctx, values, indices, out); } +Status Take(FunctionContext* ctx, const ChunkedArray& values, const Array& indices, + const TakeOptions& options, std::shared_ptr* out) { + auto num_chunks = values.num_chunks(); + std::vector> new_chunks(1); // Hard-coded 1 for now + std::shared_ptr current_chunk; + + // Case 1: `values` has a single chunk, so just use it + if (num_chunks == 
1) { + current_chunk = values.chunk(0); + } else { + // TODO Case 2: See if all `indices` fall in the same chunk and call Array Take on it + // See + // https://github.com/apache/arrow/blob/6f2c9041137001f7a9212f244b51bc004efc29af/r/src/compute.cpp#L123-L151 + // TODO Case 3: If indices are sorted, can slice them and call Array Take + + // Case 4: Else, concatenate chunks and call Array Take + RETURN_NOT_OK(Concatenate(values.chunks(), default_memory_pool(), &current_chunk)); + } + // Call Array Take on our single chunk + RETURN_NOT_OK(Take(ctx, *current_chunk, indices, options, &new_chunks[0])); + *out = std::make_shared<ChunkedArray>(std::move(new_chunks)); + return Status::OK(); +} + +Status Take(FunctionContext* ctx, const ChunkedArray& values, const ChunkedArray& indices, + const TakeOptions& options, std::shared_ptr<ChunkedArray>* out) { + auto num_chunks = indices.num_chunks(); + std::vector<std::shared_ptr<Array>> new_chunks(num_chunks); + std::shared_ptr<ChunkedArray> current_chunk; + + for (int i = 0; i < num_chunks; i++) { + // Take with that indices chunk + // Note that as currently implemented, this is inefficient because `values` + // will get concatenated on every iteration of this loop + RETURN_NOT_OK(Take(ctx, values, *indices.chunk(i), options, &current_chunk)); + // Concatenate the result to make a single array for this chunk + RETURN_NOT_OK( + Concatenate(current_chunk->chunks(), default_memory_pool(), &new_chunks[i])); + } + *out = std::make_shared<ChunkedArray>(std::move(new_chunks)); + return Status::OK(); +} + +Status Take(FunctionContext* ctx, const Array& values, const ChunkedArray& indices, + const TakeOptions& options, std::shared_ptr<ChunkedArray>* out) { + auto num_chunks = indices.num_chunks(); + std::vector<std::shared_ptr<Array>> new_chunks(num_chunks); + + for (int i = 0; i < num_chunks; i++) { + // Take with that indices chunk + RETURN_NOT_OK(Take(ctx, values, *indices.chunk(i), options, &new_chunks[i])); + } + *out = std::make_shared<ChunkedArray>(std::move(new_chunks)); + return Status::OK(); +} + +Status Take(FunctionContext* ctx, const RecordBatch& batch, const 
Array& indices, + const TakeOptions& options, std::shared_ptr* out) { + auto ncols = batch.num_columns(); + auto nrows = indices.length(); + + std::vector> columns(ncols); + + for (int j = 0; j < ncols; j++) { + RETURN_NOT_OK(Take(ctx, *batch.column(j), indices, options, &columns[j])); + } + *out = RecordBatch::Make(batch.schema(), nrows, columns); + return Status::OK(); +} + +Status Take(FunctionContext* ctx, const Table& table, const Array& indices, + const TakeOptions& options, std::shared_ptr
* out) { + auto ncols = table.num_columns(); + std::vector> columns(ncols); + + for (int j = 0; j < ncols; j++) { + RETURN_NOT_OK(Take(ctx, *table.column(j), indices, options, &columns[j])); + } + *out = Table::Make(table.schema(), columns); + return Status::OK(); +} + +Status Take(FunctionContext* ctx, const Table& table, const ChunkedArray& indices, + const TakeOptions& options, std::shared_ptr
* out) { + auto ncols = table.num_columns(); + std::vector> columns(ncols); + + for (int j = 0; j < ncols; j++) { + RETURN_NOT_OK(Take(ctx, *table.column(j), indices, options, &columns[j])); + } + *out = Table::Make(table.schema(), columns); + return Status::OK(); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h index f064b7265ed..26302b3bc5c 100644 --- a/cpp/src/arrow/compute/kernels/take.h +++ b/cpp/src/arrow/compute/kernels/take.h @@ -53,6 +53,119 @@ ARROW_EXPORT Status Take(FunctionContext* ctx, const Array& values, const Array& indices, const TakeOptions& options, std::shared_ptr* out); +/// \brief Take from a chunked array of values at indices in another array +/// +/// The output chunked array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] ctx the FunctionContext +/// \param[in] values chunked array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Status Take(FunctionContext* ctx, const ChunkedArray& values, const Array& indices, + const TakeOptions& options, std::shared_ptr* out); + +/// \brief Take from a chunked array of values at indices in a chunked array +/// +/// The output chunked array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// The chunks in the output array will align with the chunks in the indices. 
+/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] ctx the FunctionContext +/// \param[in] values chunked array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Status Take(FunctionContext* ctx, const ChunkedArray& values, const ChunkedArray& indices, + const TakeOptions& options, std::shared_ptr* out); + +/// \brief Take from an array of values at indices in a chunked array +/// +/// The output chunked array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// The chunks in the output array will align with the chunks in the indices. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] ctx the FunctionContext +/// \param[in] values array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Status Take(FunctionContext* ctx, const Array& values, const ChunkedArray& indices, + const TakeOptions& options, std::shared_ptr* out); + +/// \brief Take from a record batch at indices in another array +/// +/// The output batch will have the same schema as the input batch, +/// with rows taken from the columns in the batch at the given +/// indices. If an index is null then the taken element will be null. 
+/// +/// \param[in] ctx the FunctionContext +/// \param[in] batch record batch from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting record batch +/// NOTE: Experimental API +ARROW_EXPORT +Status Take(FunctionContext* ctx, const RecordBatch& batch, const Array& indices, + const TakeOptions& options, std::shared_ptr* out); + +/// \brief Take from a table at indices in an array +/// +/// The output table will have the same schema as the input table, +/// with rows taken from the columns in the table at the given +/// indices. If an index is null then the taken element will be null. +/// +/// \param[in] ctx the FunctionContext +/// \param[in] table table from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting table +/// NOTE: Experimental API +ARROW_EXPORT +Status Take(FunctionContext* ctx, const Table& table, const Array& indices, + const TakeOptions& options, std::shared_ptr
* out); + +/// \brief Take from a table at indices in a chunked array +/// +/// The output table will have the same schema as the input table, +/// with rows taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// +/// \param[in] ctx the FunctionContext +/// \param[in] table table from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting table +/// NOTE: Experimental API +ARROW_EXPORT +Status Take(FunctionContext* ctx, const Table& table, const ChunkedArray& indices, + const TakeOptions& options, std::shared_ptr
* out); + /// \brief Take from an array of values at indices in another array /// /// \param[in] ctx the FunctionContext diff --git a/cpp/src/arrow/compute/kernels/take_test.cc b/cpp/src/arrow/compute/kernels/take_test.cc index 9cd689f601d..c886a00ead9 100644 --- a/cpp/src/arrow/compute/kernels/take_test.cc +++ b/cpp/src/arrow/compute/kernels/take_test.cc @@ -527,5 +527,174 @@ TEST_F(TestPermutationsWithTake, InvertPermutation) { } } +class TestTakeKernelWithRecordBatch : public TestTakeKernel { + public: + void AssertTake(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& indices, const std::string& expected_batch) { + std::shared_ptr actual; + + for (auto index_type : {int8(), uint32()}) { + ASSERT_OK(this->Take(schm, batch_json, index_type, indices, &actual)); + ASSERT_OK(actual->Validate()); + ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); + } + } + + Status Take(const std::shared_ptr& schm, const std::string& batch_json, + const std::shared_ptr& index_type, const std::string& indices, + std::shared_ptr* out) { + auto batch = RecordBatchFromJSON(schm, batch_json); + TakeOptions options; + return arrow::compute::Take(&this->ctx_, *batch, *ArrayFromJSON(index_type, indices), + options, out); + } +}; + +TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + auto struct_json = R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertTake(schm, struct_json, "[]", "[]"); + this->AssertTake(schm, struct_json, "[3, 1, 3, 1, 3]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"} + ])"); + this->AssertTake(schm, struct_json, "[3, 1, 0]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": null, "b": "yo"} + ])"); + this->AssertTake(schm, struct_json, "[0, 1, 2, 3]", struct_json); + 
this->AssertTake(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + {"a": null, "b": "yo"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"} + ])"); +} + +class TestTakeKernelWithChunkedArray : public TestTakeKernel { + public: + void AssertTake(const std::shared_ptr& type, + const std::vector& values, const std::string& indices, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->TakeWithArray(type, values, indices, &actual)); + ASSERT_OK(actual->Validate()); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + void AssertChunkedTake(const std::shared_ptr& type, + const std::vector& values, + const std::vector& indices, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->TakeWithChunkedArray(type, values, indices, &actual)); + ASSERT_OK(actual->Validate()); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + Status TakeWithArray(const std::shared_ptr& type, + const std::vector& values, const std::string& indices, + std::shared_ptr* out) { + TakeOptions options; + return arrow::compute::Take(&this->ctx_, *ChunkedArrayFromJSON(type, values), + *ArrayFromJSON(int8(), indices), options, out); + } + + Status TakeWithChunkedArray(const std::shared_ptr& type, + const std::vector& values, + const std::vector& indices, + std::shared_ptr* out) { + TakeOptions options; + return arrow::compute::Take(&this->ctx_, *ChunkedArrayFromJSON(type, values), + *ChunkedArrayFromJSON(int8(), indices), options, out); + } +}; + +TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) { + this->AssertTake(int8(), {"[]"}, "[]", {"[]"}); + this->AssertChunkedTake(int8(), {"[]"}, {"[]"}, {"[]"}); + + this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"}); + this->AssertChunkedTake(int8(), {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]", "[2]"}, + {"[7, 8, 7]", "[]", "[9]"}); + 
this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"}); + + std::shared_ptr arr; + ASSERT_RAISES(IndexError, + this->TakeWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 5]", &arr)); + ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[7]", "[8, 9]"}, + {"[0, 1, 0]", "[5, 1]"}, &arr)); +} + +class TestTakeKernelWithTable : public TestTakeKernel
{ + public: + void AssertTake(const std::shared_ptr& schm, + const std::vector& table_json, const std::string& filter, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->TakeWithArray(schm, table_json, filter, &actual)); + ASSERT_OK(actual->Validate()); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + void AssertChunkedTake(const std::shared_ptr& schm, + const std::vector& table_json, + const std::vector& filter, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->TakeWithChunkedArray(schm, table_json, filter, &actual)); + ASSERT_OK(actual->Validate()); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + Status TakeWithArray(const std::shared_ptr& schm, + const std::vector& values, const std::string& indices, + std::shared_ptr
* out) { + TakeOptions options; + return arrow::compute::Take(&this->ctx_, *TableFromJSON(schm, values), + *ArrayFromJSON(int8(), indices), options, out); + } + + Status TakeWithChunkedArray(const std::shared_ptr& schm, + const std::vector& values, + const std::vector& indices, + std::shared_ptr
* out) { + TakeOptions options; + return arrow::compute::Take(&this->ctx_, *TableFromJSON(schm, values), + *ChunkedArrayFromJSON(int8(), indices), options, out); + } +}; + +TEST_F(TestTakeKernelWithTable, TakeTable) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + std::vector table_json = { + "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]", + "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"}; + + this->AssertTake(schm, table_json, "[]", {"[]"}); + std::vector expected_310 = { + "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\": \"yo\"}]"}; + this->AssertTake(schm, table_json, "[3, 1, 0]", expected_310); + this->AssertChunkedTake(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 393e7a977a3..b0a7dbca42e 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -151,6 +151,15 @@ std::shared_ptr ArrayFromJSON(const std::shared_ptr& type, return out; } +std::shared_ptr ChunkedArrayFromJSON(const std::shared_ptr& type, + const std::vector& json) { + ArrayVector out_chunks; + for (const std::string& chunk_json : json) { + out_chunks.push_back(ArrayFromJSON(type, chunk_json)); + } + return std::make_shared(std::move(out_chunks)); +} + std::shared_ptr RecordBatchFromJSON(const std::shared_ptr& schema, util::string_view json) { // Parses as a StructArray @@ -164,6 +173,18 @@ std::shared_ptr RecordBatchFromJSON(const std::shared_ptr& return record_batch; } +std::shared_ptr
<Table> TableFromJSON(const std::shared_ptr<Schema>& schema, + const std::vector<std::string>& json) { + std::vector<std::shared_ptr<RecordBatch>> batches; + for (const std::string& batch_json : json) { + batches.push_back(RecordBatchFromJSON(schema, batch_json)); + } + std::shared_ptr<Table>
table; + ABORT_NOT_OK(Table::FromRecordBatches(schema, batches, &table)); + + return table; +} + void AssertTablesEqual(const Table& expected, const Table& actual, bool same_chunk_layout, bool combine_chunks) { ASSERT_EQ(expected.num_columns(), actual.num_columns()); diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 8bae17f8e75..209fa295bc3 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -152,6 +152,7 @@ using ArrayVector = std::vector>; #define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs)) #define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs)) +#define ASSERT_TABLES_EQUAL(lhs, rhs) AssertTablesEqual((lhs), (rhs)) // If verbose is true, then the arrays will be pretty printed ARROW_EXPORT void AssertArraysEqual(const Array& expected, const Array& actual, @@ -213,6 +214,14 @@ std::shared_ptr ArrayFromJSON(const std::shared_ptr&, ARROW_EXPORT std::shared_ptr RecordBatchFromJSON( const std::shared_ptr&, util::string_view); +ARROW_EXPORT +std::shared_ptr ChunkedArrayFromJSON(const std::shared_ptr&, + const std::vector& json); + +ARROW_EXPORT +std::shared_ptr
TableFromJSON(const std::shared_ptr&, + const std::vector& json); + // ArrayFromVector: construct an Array from vectors of C values template diff --git a/r/R/array.R b/r/R/array.R index 031b38ad6d1..d05d91d35a5 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -100,7 +100,10 @@ Array <- R6Class("Array", if (is.integer(i)) { i <- Array$create(i) } - assert_is(i, "Array") # Support ChunkedArray too? + if (inherits(i, "ChunkedArray")) { + return(shared_ptr(ChunkedArray, Array__TakeChunked(self, i))) + } + assert_is(i, "Array") shared_ptr(Array, Array__Take(self, i)) }, Filter = function(i) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index e68a577ab76..7bf434a2c3e 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -272,6 +272,10 @@ Array__Take <- function(values, indices){ .Call(`_arrow_Array__Take` , values, indices) } +Array__TakeChunked <- function(values, indices){ + .Call(`_arrow_Array__TakeChunked` , values, indices) +} + RecordBatch__Take <- function(batch, indices){ .Call(`_arrow_RecordBatch__Take` , batch, indices) } @@ -280,10 +284,18 @@ ChunkedArray__Take <- function(values, indices){ .Call(`_arrow_ChunkedArray__Take` , values, indices) } +ChunkedArray__TakeChunked <- function(values, indices){ + .Call(`_arrow_ChunkedArray__TakeChunked` , values, indices) +} + Table__Take <- function(table, indices){ .Call(`_arrow_Table__Take` , table, indices) } +Table__TakeChunked <- function(table, indices){ + .Call(`_arrow_Table__TakeChunked` , table, indices) +} + Array__Filter <- function(values, filter){ .Call(`_arrow_Array__Filter` , values, filter) } diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index 6e977f2422a..5711a4627e2 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -69,13 +69,16 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = Object, } }, Take = function(i) { - if (inherits(i, c("Array", "ChunkedArray"))) { - # Hack because ChunkedArray__Take doesn't take Arrays - i <- as.vector(i) - } else if (is.numeric(i)) { + 
if (is.numeric(i)) { i <- as.integer(i) } - assert_is(i, "integer") + if (is.integer(i)) { + i <- Array$create(i) + } + if (inherits(i, "ChunkedArray")) { + return(shared_ptr(ChunkedArray, ChunkedArray__TakeChunked(self, i))) + } + assert_is(i, "Array") return(shared_ptr(ChunkedArray, ChunkedArray__Take(self, i))) }, Filter = function(i) { diff --git a/r/R/table.R b/r/R/table.R index 16a869abe95..3732e1447d1 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -139,13 +139,16 @@ Table <- R6Class("Table", inherit = Object, } }, Take = function(i) { - if (inherits(i, c("Array", "ChunkedArray"))) { - # Hack because ChunkedArray__Take doesn't take Arrays - i <- as.vector(i) - } else if (is.numeric(i)) { + if (is.numeric(i)) { i <- as.integer(i) } - assert_is(i, "integer") + if (is.integer(i)) { + i <- Array$create(i) + } + if (inherits(i, "ChunkedArray")) { + return(shared_ptr(Table, Table__TakeChunked(self, i))) + } + assert_is(i, "Array") shared_ptr(Table, Table__Take(self, i)) }, Filter = function(i) { diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index ac7777f5af5..3fd3d7a8b30 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1068,6 +1068,22 @@ RcppExport SEXP _arrow_Array__Take(SEXP values_sexp, SEXP indices_sexp){ } #endif +// compute.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr Array__TakeChunked(const std::shared_ptr& values, const std::shared_ptr& indices); +RcppExport SEXP _arrow_Array__TakeChunked(SEXP values_sexp, SEXP indices_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type values(values_sexp); + Rcpp::traits::input_parameter&>::type indices(indices_sexp); + return Rcpp::wrap(Array__TakeChunked(values, indices)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Array__TakeChunked(SEXP values_sexp, SEXP indices_sexp){ + Rf_error("Cannot call Array__TakeChunked(). Please use arrow::install_arrow() to install required runtime libraries. 
"); +} +#endif + // compute.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr RecordBatch__Take(const std::shared_ptr& batch, const std::shared_ptr& indices); @@ -1086,11 +1102,11 @@ RcppExport SEXP _arrow_RecordBatch__Take(SEXP batch_sexp, SEXP indices_sexp){ // compute.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr ChunkedArray__Take(const std::shared_ptr& values, Rcpp::IntegerVector& indices); +std::shared_ptr ChunkedArray__Take(const std::shared_ptr& values, const std::shared_ptr& indices); RcppExport SEXP _arrow_ChunkedArray__Take(SEXP values_sexp, SEXP indices_sexp){ BEGIN_RCPP Rcpp::traits::input_parameter&>::type values(values_sexp); - Rcpp::traits::input_parameter::type indices(indices_sexp); + Rcpp::traits::input_parameter&>::type indices(indices_sexp); return Rcpp::wrap(ChunkedArray__Take(values, indices)); END_RCPP } @@ -1102,11 +1118,27 @@ RcppExport SEXP _arrow_ChunkedArray__Take(SEXP values_sexp, SEXP indices_sexp){ // compute.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr Table__Take(const std::shared_ptr& table, Rcpp::IntegerVector& indices); +std::shared_ptr ChunkedArray__TakeChunked(const std::shared_ptr& values, const std::shared_ptr& indices); +RcppExport SEXP _arrow_ChunkedArray__TakeChunked(SEXP values_sexp, SEXP indices_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type values(values_sexp); + Rcpp::traits::input_parameter&>::type indices(indices_sexp); + return Rcpp::wrap(ChunkedArray__TakeChunked(values, indices)); +END_RCPP +} +#else +RcppExport SEXP _arrow_ChunkedArray__TakeChunked(SEXP values_sexp, SEXP indices_sexp){ + Rf_error("Cannot call ChunkedArray__TakeChunked(). Please use arrow::install_arrow() to install required runtime libraries. 
"); +} +#endif + +// compute.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr Table__Take(const std::shared_ptr& table, const std::shared_ptr& indices); RcppExport SEXP _arrow_Table__Take(SEXP table_sexp, SEXP indices_sexp){ BEGIN_RCPP Rcpp::traits::input_parameter&>::type table(table_sexp); - Rcpp::traits::input_parameter::type indices(indices_sexp); + Rcpp::traits::input_parameter&>::type indices(indices_sexp); return Rcpp::wrap(Table__Take(table, indices)); END_RCPP } @@ -1116,6 +1148,22 @@ RcppExport SEXP _arrow_Table__Take(SEXP table_sexp, SEXP indices_sexp){ } #endif +// compute.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr Table__TakeChunked(const std::shared_ptr& table, const std::shared_ptr& indices); +RcppExport SEXP _arrow_Table__TakeChunked(SEXP table_sexp, SEXP indices_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type table(table_sexp); + Rcpp::traits::input_parameter&>::type indices(indices_sexp); + return Rcpp::wrap(Table__TakeChunked(table, indices)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Table__TakeChunked(SEXP table_sexp, SEXP indices_sexp){ + Rf_error("Cannot call Table__TakeChunked(). Please use arrow::install_arrow() to install required runtime libraries. 
"); +} +#endif + // compute.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr Array__Filter(const std::shared_ptr& values, const std::shared_ptr& filter); @@ -5033,9 +5081,12 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_RecordBatch__cast", (DL_FUNC) &_arrow_RecordBatch__cast, 3}, { "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3}, { "_arrow_Array__Take", (DL_FUNC) &_arrow_Array__Take, 2}, + { "_arrow_Array__TakeChunked", (DL_FUNC) &_arrow_Array__TakeChunked, 2}, { "_arrow_RecordBatch__Take", (DL_FUNC) &_arrow_RecordBatch__Take, 2}, { "_arrow_ChunkedArray__Take", (DL_FUNC) &_arrow_ChunkedArray__Take, 2}, + { "_arrow_ChunkedArray__TakeChunked", (DL_FUNC) &_arrow_ChunkedArray__TakeChunked, 2}, { "_arrow_Table__Take", (DL_FUNC) &_arrow_Table__Take, 2}, + { "_arrow_Table__TakeChunked", (DL_FUNC) &_arrow_Table__TakeChunked, 2}, { "_arrow_Array__Filter", (DL_FUNC) &_arrow_Array__Filter, 2}, { "_arrow_RecordBatch__Filter", (DL_FUNC) &_arrow_RecordBatch__Filter, 2}, { "_arrow_ChunkedArray__Filter", (DL_FUNC) &_arrow_ChunkedArray__Filter, 2}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index bfb5f6f01e2..6f78bcb6316 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -92,86 +92,74 @@ std::shared_ptr Array__Take(const std::shared_ptr& v return out; } +// [[arrow::export]] +std::shared_ptr Array__TakeChunked( + const std::shared_ptr& values, + const std::shared_ptr& indices) { + std::shared_ptr out; + arrow::compute::FunctionContext context; + arrow::compute::TakeOptions options; + + STOP_IF_NOT_OK(arrow::compute::Take(&context, *values, *indices, options, &out)); + return out; +} + // [[arrow::export]] std::shared_ptr RecordBatch__Take( const std::shared_ptr& batch, const std::shared_ptr& indices) { - int ncols = batch->num_columns(); - auto nrows = indices->length(); - - std::vector> columns(ncols); - - for (R_xlen_t j = 0; j < ncols; j++) { - columns[j] = Array__Take(batch->column(j), indices); - } - - return 
arrow::RecordBatch::Make(batch->schema(), nrows, columns); + std::shared_ptr out; + arrow::compute::FunctionContext context; + arrow::compute::TakeOptions options; + STOP_IF_NOT_OK(arrow::compute::Take(&context, *batch, *indices, options, &out)); + return out; } // [[arrow::export]] std::shared_ptr ChunkedArray__Take( - const std::shared_ptr& values, Rcpp::IntegerVector& indices) { - int num_chunks = values->num_chunks(); - std::vector> new_chunks(1); // Hard-coded 1 for now - // 1) If there's only one chunk, just take from it - if (num_chunks == 1) { - new_chunks[0] = Array__Take( - values->chunk(0), arrow::r::Array__from_vector(indices, arrow::int32(), true)); - return std::make_shared(std::move(new_chunks)); - } - - std::shared_ptr current_chunk; - std::shared_ptr current_indices; - int offset = 0; - int len; - int min_i = indices[0]; - int max_i = indices[0]; + const std::shared_ptr& values, + const std::shared_ptr& indices) { + std::shared_ptr out; + arrow::compute::FunctionContext context; + arrow::compute::TakeOptions options; - // 2) See if all i are in the same chunk, call Array__Take on that - for (R_xlen_t i = 1; i < indices.size(); i++) { - if (indices[i] < min_i) { - min_i = indices[i]; - } else if (indices[i] > max_i) { - max_i = indices[i]; - } - } - for (R_xlen_t chk = 0; chk < num_chunks; chk++) { - current_chunk = values->chunk(chk); - len = current_chunk->length(); - if (min_i >= offset && max_i < offset + len) { - for (R_xlen_t i = 0; i < indices.size(); i++) { - // Subtract offset from all indices - indices[i] -= offset; - } - current_indices = arrow::r::Array__from_vector(indices, arrow::int32(), true); - new_chunks[0] = Array__Take(current_chunk, current_indices); - return std::make_shared(std::move(new_chunks)); - } - offset += len; - } + STOP_IF_NOT_OK(arrow::compute::Take(&context, *values, *indices, options, &out)); + return out; +} - // TODO 3) If they're not all in the same chunk but are sorted, we can slice - // the indices (offset 
appropriately) and take from each chunk +// [[arrow::export]] +std::shared_ptr ChunkedArray__TakeChunked( + const std::shared_ptr& values, + const std::shared_ptr& indices) { + std::shared_ptr out; + arrow::compute::FunctionContext context; + arrow::compute::TakeOptions options; - // 4) Last resort: concatenate the chunks - STOP_IF_NOT_OK( - arrow::Concatenate(values->chunks(), arrow::default_memory_pool(), ¤t_chunk)); - current_indices = arrow::r::Array__from_vector(indices, arrow::int32(), true); - new_chunks[0] = Array__Take(current_chunk, current_indices); - return std::make_shared(std::move(new_chunks)); + STOP_IF_NOT_OK(arrow::compute::Take(&context, *values, *indices, options, &out)); + return out; } // [[arrow::export]] std::shared_ptr Table__Take(const std::shared_ptr& table, - Rcpp::IntegerVector& indices) { - auto ncols = table->num_columns(); - std::vector> columns(ncols); + const std::shared_ptr& indices) { + std::shared_ptr out; + arrow::compute::FunctionContext context; + arrow::compute::TakeOptions options; - for (R_xlen_t j = 0; j < ncols; j++) { - columns[j] = ChunkedArray__Take(table->column(j), indices); - } + STOP_IF_NOT_OK(arrow::compute::Take(&context, *table, *indices, options, &out)); + return out; +} - return arrow::Table::Make(table->schema(), columns); +// [[arrow::export]] +std::shared_ptr Table__TakeChunked( + const std::shared_ptr& table, + const std::shared_ptr& indices) { + std::shared_ptr out; + arrow::compute::FunctionContext context; + arrow::compute::TakeOptions options; + + STOP_IF_NOT_OK(arrow::compute::Take(&context, *table, *indices, options, &out)); + return out; } // [[arrow::export]] @@ -187,92 +175,48 @@ std::shared_ptr Array__Filter(const std::shared_ptr& std::shared_ptr RecordBatch__Filter( const std::shared_ptr& batch, const std::shared_ptr& filter) { - int ncols = batch->num_columns(); - - std::vector> columns(ncols); - - for (R_xlen_t j = 0; j < ncols; j++) { - columns[j] = Array__Filter(batch->column(j), filter); - 
} - - return arrow::RecordBatch::Make(batch->schema(), columns[0]->length(), columns); + std::shared_ptr out; + arrow::compute::FunctionContext context; + STOP_IF_NOT_OK(arrow::compute::Filter(&context, *batch, *filter, &out)); + return out; } // [[arrow::export]] std::shared_ptr ChunkedArray__Filter( const std::shared_ptr& values, const std::shared_ptr& filter) { - int num_chunks = values->num_chunks(); - std::vector> new_chunks(num_chunks); - std::shared_ptr current_chunk; - int offset = 0; - int len; - - for (R_xlen_t i = 0; i < num_chunks; i++) { - current_chunk = values->chunk(i); - len = current_chunk->length(); - new_chunks[i] = Array__Filter(current_chunk, filter->Slice(offset, len)); - offset += len; - } - - return std::make_shared(std::move(new_chunks)); + std::shared_ptr out; + arrow::compute::FunctionContext context; + STOP_IF_NOT_OK(arrow::compute::Filter(&context, *values, *filter, &out)); + return out; } // [[arrow::export]] std::shared_ptr ChunkedArray__FilterChunked( const std::shared_ptr& values, const std::shared_ptr& filter) { - int num_chunks = values->num_chunks(); - std::vector> new_chunks(num_chunks); - std::shared_ptr current_chunk; - std::shared_ptr current_chunked_filter; - std::shared_ptr current_filter; - - int offset = 0; - int len; - - for (R_xlen_t i = 0; i < num_chunks; i++) { - current_chunk = values->chunk(i); - len = current_chunk->length(); - current_chunked_filter = filter->Slice(offset, len); - if (current_chunked_filter->num_chunks() == 1) { - current_filter = current_chunked_filter->chunk(0); - } else { - // Concatenate the chunks of the filter so we have an Array - STOP_IF_NOT_OK(arrow::Concatenate(current_chunked_filter->chunks(), - arrow::default_memory_pool(), ¤t_filter)); - } - new_chunks[i] = Array__Filter(current_chunk, current_filter); - offset += len; - } - - return std::make_shared(std::move(new_chunks)); + std::shared_ptr out; + arrow::compute::FunctionContext context; + 
STOP_IF_NOT_OK(arrow::compute::Filter(&context, *values, *filter, &out)); + return out; } // [[arrow::export]] std::shared_ptr Table__Filter(const std::shared_ptr& table, const std::shared_ptr& filter) { - auto ncols = table->num_columns(); - std::vector> columns(ncols); - - for (R_xlen_t j = 0; j < ncols; j++) { - columns[j] = ChunkedArray__Filter(table->column(j), filter); - } - - return arrow::Table::Make(table->schema(), columns); + std::shared_ptr out; + arrow::compute::FunctionContext context; + STOP_IF_NOT_OK(arrow::compute::Filter(&context, *table, *filter, &out)); + return out; } // [[arrow::export]] std::shared_ptr Table__FilterChunked( const std::shared_ptr& table, const std::shared_ptr& filter) { - auto ncols = table->num_columns(); - std::vector> columns(ncols); - - for (R_xlen_t j = 0; j < ncols; j++) { - columns[j] = ChunkedArray__FilterChunked(table->column(j), filter); - } - - return arrow::Table::Make(table->schema(), columns); + std::shared_ptr out; + arrow::compute::FunctionContext context; + STOP_IF_NOT_OK(arrow::compute::Filter(&context, *table, *filter, &out)); + return out; } #endif diff --git a/r/tests/testthat/test-Array.R b/r/tests/testthat/test-Array.R index 73677705afd..d5b7e7f69e7 100644 --- a/r/tests/testthat/test-Array.R +++ b/r/tests/testthat/test-Array.R @@ -499,11 +499,7 @@ test_that("[ accepts Arrays and otherwise handles bad input", { ) expect_vector(a[Array$create(ind - 1, type = int8())], vec[ind]) expect_vector(a[Array$create(ind - 1, type = uint8())], vec[ind]) - expect_error( - # Not currently supported - a[ChunkedArray$create(8, 2, 4, type = uint8())], - 'i must be a "Array"' - ) + expect_vector(a[ChunkedArray$create(8, 2, 4, type = uint8())], vec[ind]) filt <- seq_along(vec) %in% ind expect_vector(a[Array$create(filt)], vec[filt]) diff --git a/r/tests/testthat/test-Table.R b/r/tests/testthat/test-Table.R index fbc6274cd03..ff3bff4f191 100644 --- a/r/tests/testthat/test-Table.R +++ b/r/tests/testthat/test-Table.R @@ -110,6 
+110,8 @@ test_that("[, [[, $ for Table", { expect_data_frame(tab[ca,], tbl[c(1, 3, 4, 8, 9),]) # int Array expect_data_frame(tab[Array$create(5:6), 2:4], tbl[6:7, 2:4]) + # ChunkedArray + expect_data_frame(tab[ChunkedArray$create(5L, 6L), 2:4], tbl[6:7, 2:4]) # Expression expect_data_frame(tab[tab$int > 6,], tbl[tbl$int > 6,]) diff --git a/r/tests/testthat/test-chunked-array.R b/r/tests/testthat/test-chunked-array.R index 02a92612799..1fa399db936 100644 --- a/r/tests/testthat/test-chunked-array.R +++ b/r/tests/testthat/test-chunked-array.R @@ -361,6 +361,12 @@ test_that("[ ChunkedArray", { expect_vector(x[c(11, 15, 12)], c(31, 35, 32)) # Take from multiple chunks (calls Concatenate) expect_vector(x[c(2, 11, 15, 12, 3)], c(2, 31, 35, 32, 3)) + # Take with Array (note these are 0-based) + take1 <- Array$create(c(10L, 14L, 11L)) + expect_vector(x[take1], c(31, 35, 32)) + # Take with ChunkedArray + take2 <- ChunkedArray$create(c(10L, 14L), 11L) + expect_vector(x[take2], c(31, 35, 32)) # Filter (with recycling) expect_vector(