Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 128 additions & 124 deletions cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "arrow/compute/kernels/gather_internal.h"
#include "arrow/compute/kernels/vector_selection_internal.h"
#include "arrow/compute/kernels/vector_selection_take_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/table.h"
Expand Down Expand Up @@ -536,142 +537,144 @@ Status ExtensionTake(KernelContext* ctx, const ExecSpan& batch, ExecResult* out)
// R -> RecordBatch
// T -> Table

// Array-Array take: forwards `values` and `indices` to the "array_take"
// function through CallFunction and hands back the ArrayData of the result.
// A failed call is returned to the caller as-is.
Result<std::shared_ptr<ArrayData>> TakeAAA(const std::shared_ptr<ArrayData>& values,
                                           const std::shared_ptr<ArrayData>& indices,
                                           const TakeOptions& options, ExecContext* ctx) {
  Result<Datum> maybe_taken =
      CallFunction("array_take", {values, indices}, &options, ctx);
  if (!maybe_taken.ok()) {
    return maybe_taken.status();
  }
  return maybe_taken->array();
}
// Documentation object for the "take" function; it is passed to the
// MetaFunction constructor by TakeMetaFunction. The strings below are
// user-facing text and must not be reworded casually.
const FunctionDoc take_doc(
"Select values from an input based on indices from another array",
("The output is populated with values from the input at positions\n"
"given by `indices`. Nulls in `indices` emit null in the output."),
{"input", "indices"}, "TakeOptions");

Result<std::shared_ptr<ChunkedArray>> TakeCAC(const ChunkedArray& values,
const Array& indices,
const TakeOptions& options,
ExecContext* ctx) {
std::shared_ptr<Array> values_array;
if (values.num_chunks() == 1) {
// Case 1: `values` has a single chunk, so just use it
values_array = values.chunk(0);
} else {
// TODO Case 2: See if all `indices` fall in the same chunk and call Array Take on it
// See
// https://github.com/apache/arrow/blob/6f2c9041137001f7a9212f244b51bc004efc29af/r/src/compute.cpp#L123-L151
// TODO Case 3: If indices are sorted, can slice them and call Array Take
// (these are relevant to TakeCCC as well)

// Case 4: Else, concatenate chunks and call Array Take
if (values.chunks().empty()) {
ARROW_ASSIGN_OR_RAISE(
values_array, MakeArrayOfNull(values.type(), /*length=*/0, ctx->memory_pool()));
} else {
ARROW_ASSIGN_OR_RAISE(values_array,
Concatenate(values.chunks(), ctx->memory_pool()));
}
// Metafunction for dispatching to different Take implementations other than
// Array-Array.
class TakeMetaFunction : public MetaFunction {
public:
TakeMetaFunction()
: MetaFunction("take", Arity::Binary(), take_doc, GetDefaultTakeOptions()) {}

// Looks up "array_take" in the function registry of `ctx` and executes it
// with `args` and `options`. A failed lookup is returned as the error.
static Result<Datum> CallArrayTake(const std::vector<Datum>& args,
                                   const TakeOptions& options, ExecContext* ctx) {
  auto maybe_func = ctx->func_registry()->GetFunction("array_take");
  if (!maybe_func.ok()) {
    return maybe_func.status();
  }
  return (*maybe_func)->Execute(args, &options, ctx);
}
// Call Array Take on our single chunk
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> new_chunk,
TakeAAA(values_array->data(), indices.data(), options, ctx));
std::vector<std::shared_ptr<Array>> chunks = {MakeArray(new_chunk)};
return std::make_shared<ChunkedArray>(std::move(chunks));
}

Result<std::shared_ptr<ChunkedArray>> TakeCCC(const ChunkedArray& values,
const ChunkedArray& indices,
const TakeOptions& options,
ExecContext* ctx) {
// XXX: for every chunk in indices, values are gathered from all chunks in values to
// form a new chunk in the result. Performing this concatenation is not ideal, but
// greatly simplifies the implementation before something more efficient is
// implemented.
std::shared_ptr<Array> values_array;
if (values.num_chunks() == 1) {
values_array = values.chunk(0);
} else {
if (values.chunks().empty()) {
ARROW_ASSIGN_OR_RAISE(
values_array, MakeArrayOfNull(values.type(), /*length=*/0, ctx->memory_pool()));
} else {
ARROW_ASSIGN_OR_RAISE(values_array,
Concatenate(values.chunks(), ctx->memory_pool()));
// Produces a single Array equivalent to `values`:
//  - no chunks:   a zero-length all-null array of the same type,
//  - one chunk:   that chunk itself (no copy of the data),
//  - otherwise:   the concatenation of all chunks, allocated from `pool`.
static Result<std::shared_ptr<Array>> ChunkedArrayAsArray(
    const std::shared_ptr<ChunkedArray>& values, MemoryPool* pool) {
  const int chunk_count = values->num_chunks();
  if (chunk_count == 0) {
    return MakeArrayOfNull(values->type(), /*length=*/0, pool);
  }
  if (chunk_count == 1) {
    return values->chunk(0);
  }
  return Concatenate(values->chunks(), pool);
}
std::vector<std::shared_ptr<Array>> new_chunks;
new_chunks.resize(indices.num_chunks());
for (int i = 0; i < indices.num_chunks(); i++) {
ARROW_ASSIGN_OR_RAISE(auto chunk, TakeAAA(values_array->data(),
indices.chunk(i)->data(), options, ctx));
new_chunks[i] = MakeArray(chunk);

private:
// Array-Array take. Both entries of `args` (values, indices) are expected to
// be of kind Datum::ARRAY, which is checked in debug builds only. The work is
// delegated to CallArrayTake and the ArrayData of its result is returned.
static Result<std::shared_ptr<ArrayData>> TakeAAA(const std::vector<Datum>& args,
                                                  const TakeOptions& options,
                                                  ExecContext* ctx) {
  DCHECK_EQ(args[0].kind(), Datum::ARRAY);
  DCHECK_EQ(args[1].kind(), Datum::ARRAY);
  Result<Datum> maybe_taken = CallArrayTake(args, options, ctx);
  if (!maybe_taken.ok()) {
    return maybe_taken.status();
  }
  return maybe_taken->array();
}
return std::make_shared<ChunkedArray>(std::move(new_chunks), values.type());
}

Result<std::shared_ptr<ChunkedArray>> TakeACC(const Array& values,
const ChunkedArray& indices,
const TakeOptions& options,
ExecContext* ctx) {
auto num_chunks = indices.num_chunks();
std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
for (int i = 0; i < num_chunks; i++) {
// Take with that indices chunk
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> chunk,
TakeAAA(values.data(), indices.chunk(i)->data(), options, ctx));
new_chunks[i] = MakeArray(chunk);
// ChunkedArray-Array take producing a single ArrayData: `values` is first
// collapsed into one Array with ChunkedArrayAsArray, then TakeAAA is applied
// to that array and `indices`.
static Result<std::shared_ptr<ArrayData>> TakeCAA(
    const std::shared_ptr<ChunkedArray>& values, const Array& indices,
    const TakeOptions& options, ExecContext* ctx) {
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> flat_values,
                        ChunkedArrayAsArray(values, ctx->memory_pool()));
  const std::vector<Datum> take_args{Datum(std::move(flat_values)), Datum(indices)};
  return TakeAAA(take_args, options, ctx);
}
return std::make_shared<ChunkedArray>(std::move(new_chunks), values.type());
}

Result<std::shared_ptr<RecordBatch>> TakeRAR(const RecordBatch& batch,
const Array& indices,
const TakeOptions& options,
ExecContext* ctx) {
auto ncols = batch.num_columns();
auto nrows = indices.length();
std::vector<std::shared_ptr<Array>> columns(ncols);
for (int j = 0; j < ncols; j++) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> col_data,
TakeAAA(batch.column(j)->data(), indices.data(), options, ctx));
columns[j] = MakeArray(col_data);
static Result<std::shared_ptr<ChunkedArray>> TakeCAC(
const std::shared_ptr<ChunkedArray>& values, const Array& indices,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious: is there a particular reason for taking the first arg as shared_ptr-const-ref, and the second only as value-const-ref?

Copy link
Contributor Author

@felipecrv felipecrv Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Driven by the signatures of the Datum constructors [1]. Originally, almost all of these parameters were taken by plain const reference (`const T&`). I then tried making all of them `const std::shared_ptr<T>&`, but that got overwhelming, so I settled on this mixed combination driven by need (Array holds a shared_ptr<ArrayData>, which is what Datum cares about).

[1]

  /// \brief Construct from a Scalar
  Datum(std::shared_ptr<Scalar> value)  // NOLINT implicit conversion
      : value(std::move(value)) {}

  /// \brief Construct from an ArrayData
  Datum(std::shared_ptr<ArrayData> value)  // NOLINT implicit conversion
      : value(std::move(value)) {}

  /// \brief Construct from an ArrayData
  Datum(ArrayData arg)  // NOLINT implicit conversion
      : value(std::make_shared<ArrayData>(std::move(arg))) {}

  /// \brief Construct from an Array
  Datum(const Array& value);  // NOLINT implicit conversion

  /// \brief Construct from an Array
  Datum(const std::shared_ptr<Array>& value);  // NOLINT implicit conversion

  /// \brief Construct from a ChunkedArray
  Datum(std::shared_ptr<ChunkedArray> value);  // NOLINT implicit conversion

  /// \brief Construct from a RecordBatch
  Datum(std::shared_ptr<RecordBatch> value);  // NOLINT implicit conversion

  /// \brief Construct from a Table
  Datum(std::shared_ptr<Table> value);  // NOLINT implicit conversion

  /// \brief Construct from a ChunkedArray.
  ///
  /// This can be expensive, prefer the shared_ptr<ChunkedArray> constructor
  explicit Datum(const ChunkedArray& value);

  /// \brief Construct from a RecordBatch.
  ///
  /// This can be expensive, prefer the shared_ptr<RecordBatch> constructor
  explicit Datum(const RecordBatch& value);

  /// \brief Construct from a Table.
  ///
  /// This can be expensive, prefer the shared_ptr<Table> constructor
  explicit Datum(const Table& value);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to simplify these signatures further in the next PR, since the "take" MetaFunction will have less to handle once "array_take" learns to handle chunked arrays by itself.

const TakeOptions& options, ExecContext* ctx) {
ARROW_ASSIGN_OR_RAISE(auto new_chunk, TakeCAA(values, indices, options, ctx));
return std::make_shared<ChunkedArray>(MakeArray(std::move(new_chunk)));
}
return RecordBatch::Make(batch.schema(), nrows, std::move(columns));
}

Result<std::shared_ptr<Table>> TakeTAT(const Table& table, const Array& indices,
const TakeOptions& options, ExecContext* ctx) {
auto ncols = table.num_columns();
std::vector<std::shared_ptr<ChunkedArray>> columns(ncols);
// ChunkedArray-ChunkedArray take. The result has one chunk per chunk of
// `indices` and carries the type of `values`.
static Result<std::shared_ptr<ChunkedArray>> TakeCCC(
    const std::shared_ptr<ChunkedArray>& values,
    const std::shared_ptr<ChunkedArray>& indices, const TakeOptions& options,
    ExecContext* ctx) {
  // XXX: every chunk of `indices` gathers from all chunks of `values`, so the
  // values are collapsed into a single array up front. That is not ideal, but
  // it keeps the implementation simple until something more efficient exists.
  ARROW_ASSIGN_OR_RAISE(auto flat_values,
                        ChunkedArrayAsArray(values, ctx->memory_pool()));
  // Slot 0 holds the values for the whole loop; slot 1 is refilled with the
  // current indices chunk on every iteration.
  std::vector<Datum> take_args = {std::move(flat_values), {}};
  const int index_chunk_count = indices->num_chunks();
  std::vector<std::shared_ptr<Array>> taken_chunks(index_chunk_count);
  for (int chunk_idx = 0; chunk_idx < index_chunk_count; ++chunk_idx) {
    take_args[1] = indices->chunk(chunk_idx);
    // XXX: this loop can use TakeCAA once it can handle ChunkedArray
    // without concatenating first
    ARROW_ASSIGN_OR_RAISE(auto taken, TakeAAA(take_args, options, ctx));
    taken_chunks[chunk_idx] = MakeArray(taken);
  }
  return std::make_shared<ChunkedArray>(std::move(taken_chunks), values->type());
}

for (int j = 0; j < ncols; j++) {
ARROW_ASSIGN_OR_RAISE(columns[j], TakeCAC(*table.column(j), indices, options, ctx));
// Array-ChunkedArray take: applies TakeAAA to `values` once per chunk of
// `indices`. The result has one chunk per indices chunk and the type of
// `values`.
static Result<std::shared_ptr<ChunkedArray>> TakeACC(const Array& values,
                                                     const ChunkedArray& indices,
                                                     const TakeOptions& options,
                                                     ExecContext* ctx) {
  std::vector<std::shared_ptr<Array>> taken_chunks;
  taken_chunks.reserve(indices.num_chunks());
  // Slot 0 (values) stays fixed; slot 1 is refilled with each indices chunk.
  std::vector<Datum> take_args = {values, {}};
  for (const auto& index_chunk : indices.chunks()) {
    take_args[1] = index_chunk;
    ARROW_ASSIGN_OR_RAISE(auto taken, TakeAAA(take_args, options, ctx));
    taken_chunks.push_back(MakeArray(taken));
  }
  return std::make_shared<ChunkedArray>(std::move(taken_chunks), values.type());
}
return Table::Make(table.schema(), std::move(columns));
}

Result<std::shared_ptr<Table>> TakeTCT(const Table& table, const ChunkedArray& indices,
const TakeOptions& options, ExecContext* ctx) {
auto ncols = table.num_columns();
std::vector<std::shared_ptr<ChunkedArray>> columns(ncols);
for (int j = 0; j < ncols; j++) {
ARROW_ASSIGN_OR_RAISE(columns[j], TakeCCC(*table.column(j), indices, options, ctx));
// RecordBatch-Array take: applies TakeAAA to every column of `batch` with the
// same `indices`. The result keeps the schema of `batch` and has
// `indices.length()` rows.
static Result<std::shared_ptr<RecordBatch>> TakeRAR(const RecordBatch& batch,
                                                    const Array& indices,
                                                    const TakeOptions& options,
                                                    ExecContext* ctx) {
  const auto column_count = batch.num_columns();
  std::vector<std::shared_ptr<Array>> taken_columns;
  taken_columns.reserve(column_count);
  // Slot 1 (indices) stays fixed; slot 0 is refilled with each column.
  std::vector<Datum> take_args = {{}, indices};
  for (int col = 0; col < column_count; ++col) {
    take_args[0] = batch.column(col);
    ARROW_ASSIGN_OR_RAISE(auto taken, TakeAAA(take_args, options, ctx));
    taken_columns.push_back(MakeArray(taken));
  }
  return RecordBatch::Make(batch.schema(), indices.length(), std::move(taken_columns));
}
return Table::Make(table.schema(), std::move(columns));
}

// Documentation object for the "take" function; it is passed to the
// MetaFunction constructor by TakeMetaFunction. The strings below are
// user-facing text and must not be reworded casually.
const FunctionDoc take_doc(
"Select values from an input based on indices from another array",
("The output is populated with values from the input at positions\n"
"given by `indices`. Nulls in `indices` emit null in the output."),
{"input", "indices"}, "TakeOptions");
static Result<std::shared_ptr<Table>> TakeTAT(const std::shared_ptr<Table>& table,
const Array& indices,
const TakeOptions& options,
ExecContext* ctx) {
auto ncols = table->num_columns();
std::vector<std::shared_ptr<ChunkedArray>> columns(ncols);

// Metafunction for dispatching to different Take implementations other than
// Array-Array.
//
// TODO: Revamp approach to executing Take operations. In addition to being
// overly complex dispatching, there is no parallelization.
class TakeMetaFunction : public MetaFunction {
public:
TakeMetaFunction()
: MetaFunction("take", Arity::Binary(), take_doc, GetDefaultTakeOptions()) {}
for (int j = 0; j < ncols; j++) {
ARROW_ASSIGN_OR_RAISE(columns[j], TakeCAC(table->column(j), indices, options, ctx));
}
return Table::Make(table->schema(), std::move(columns));
}

// Table-ChunkedArray take: applies TakeCCC to every column of `table` with
// the same chunked `indices`, then reassembles a Table with the original
// schema. The first column that fails aborts the whole operation.
static Result<std::shared_ptr<Table>> TakeTCT(
    const std::shared_ptr<Table>& table, const std::shared_ptr<ChunkedArray>& indices,
    const TakeOptions& options, ExecContext* ctx) {
  const auto column_count = table->num_columns();
  std::vector<std::shared_ptr<ChunkedArray>> taken_columns;
  taken_columns.reserve(column_count);
  for (int col = 0; col < column_count; ++col) {
    ARROW_ASSIGN_OR_RAISE(auto taken,
                          TakeCCC(table->column(col), indices, options, ctx));
    taken_columns.push_back(std::move(taken));
  }
  return Table::Make(table->schema(), std::move(taken_columns));
}

public:
Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
const FunctionOptions* options,
ExecContext* ctx) const override {
Expand All @@ -680,16 +683,16 @@ class TakeMetaFunction : public MetaFunction {
switch (args[0].kind()) {
case Datum::ARRAY:
if (index_kind == Datum::ARRAY) {
return TakeAAA(args[0].array(), args[1].array(), take_opts, ctx);
return TakeAAA(args, take_opts, ctx);
} else if (index_kind == Datum::CHUNKED_ARRAY) {
return TakeACC(*args[0].make_array(), *args[1].chunked_array(), take_opts, ctx);
}
break;
case Datum::CHUNKED_ARRAY:
if (index_kind == Datum::ARRAY) {
return TakeCAC(*args[0].chunked_array(), *args[1].make_array(), take_opts, ctx);
return TakeCAC(args[0].chunked_array(), *args[1].make_array(), take_opts, ctx);
} else if (index_kind == Datum::CHUNKED_ARRAY) {
return TakeCCC(*args[0].chunked_array(), *args[1].chunked_array(), take_opts,
return TakeCCC(args[0].chunked_array(), args[1].chunked_array(), take_opts,
ctx);
}
break;
Expand All @@ -700,12 +703,13 @@ class TakeMetaFunction : public MetaFunction {
break;
case Datum::TABLE:
if (index_kind == Datum::ARRAY) {
return TakeTAT(*args[0].table(), *args[1].make_array(), take_opts, ctx);
return TakeTAT(args[0].table(), *args[1].make_array(), take_opts, ctx);
} else if (index_kind == Datum::CHUNKED_ARRAY) {
return TakeTCT(*args[0].table(), *args[1].chunked_array(), take_opts, ctx);
return TakeTCT(args[0].table(), args[1].chunked_array(), take_opts, ctx);
}
break;
default:
case Datum::NONE:
case Datum::SCALAR:
break;
}
return Status::NotImplemented(
Expand Down