From 74dfe242cad5911460b2d9f9aab28c05bd8295bf Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 3 Aug 2021 17:33:30 -0400
Subject: [PATCH 1/6] ARROW-13540: [C++] Add order by sink node
---
cpp/src/arrow/CMakeLists.txt | 1 +
cpp/src/arrow/compute/exec/options.h | 14 ++
cpp/src/arrow/compute/exec/order_by_node.cc | 137 ++++++++++++++++++++
cpp/src/arrow/compute/exec/plan_test.cc | 120 +++++++++++++++++
cpp/src/arrow/compute/exec/sink_node.cc | 70 +++++++++-
5 files changed, 340 insertions(+), 2 deletions(-)
create mode 100644 cpp/src/arrow/compute/exec/order_by_node.cc
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 308ee49972c..90fb2fae093 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -374,6 +374,7 @@ if(ARROW_COMPUTE)
compute/exec/exec_plan.cc
compute/exec/expression.cc
compute/exec/filter_node.cc
+ compute/exec/order_by_node.cc
compute/exec/project_node.cc
compute/exec/source_node.cc
compute/exec/sink_node.cc
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index a853a74362d..7ea21d707ae 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -23,6 +23,7 @@
#include
#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/util/optional.h"
@@ -63,6 +64,19 @@ class ARROW_EXPORT FilterNodeOptions : public ExecNodeOptions {
Expression filter_expression;
};
+/// \brief Make a node which sorts rows passed through it
+///
+/// All batches pushed to this node will be accumulated, then sorted, by the given
+/// fields. Then sorted batches will be pushed to the next node, along a tag
+/// indicating the absolute order of the batches.
+class ARROW_EXPORT OrderByNodeOptions : public ExecNodeOptions {
+ public:
+ explicit OrderByNodeOptions(SortOptions sort_options)
+ : sort_options(std::move(sort_options)) {}
+
+ SortOptions sort_options;
+};
+
/// \brief Make a node which executes expressions on input batches, producing new batches.
///
/// Each expression will be evaluated against each batch which is pushed to
diff --git a/cpp/src/arrow/compute/exec/order_by_node.cc b/cpp/src/arrow/compute/exec/order_by_node.cc
new file mode 100644
index 00000000000..e0f2445796e
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/order_by_node.cc
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/exec_plan.h"
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/table.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+namespace compute {
+
+using arrow::internal::checked_cast;
+
+// Simple in-memory sort node. Accumulates all data, then sorts and
+// emits output batches in order.
+struct OrderByNode final : public ExecNode {
+ OrderByNode(ExecPlan* plan, std::vector inputs, SortOptions sort_options)
+ : ExecNode(plan, std::move(inputs), {"target"}, inputs[0]->output_schema(),
+ /*num_outputs=*/1),
+ sort_options_(std::move(sort_options)) {}
+
+ const char* kind_name() override { return "OrderByNode"; }
+
+ static Result Make(ExecPlan* plan, std::vector inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderByNode"));
+
+ const auto& order_by_options = checked_cast(options);
+ std::vector fields;
+ fields.reserve((order_by_options.sort_options.sort_keys.size()));
+ for (const auto& key : order_by_options.sort_options.sort_keys)
+ fields.push_back(key.name);
+ RETURN_NOT_OK(inputs[0]->output_schema()->CanReferenceFieldsByNames(fields));
+
+ return plan->EmplaceNode(plan, std::move(inputs),
+ order_by_options.sort_options);
+ }
+
+ Status StartProducing() override {
+ finished_ = Future<>::Make();
+ return Status::OK();
+ }
+
+ void PauseProducing(ExecNode* output) override {}
+
+ void ResumeProducing(ExecNode* output) override {}
+
+ void StopProducing(ExecNode* output) override {
+ DCHECK_EQ(output, outputs_[0]);
+ StopProducing();
+ }
+
+ void StopProducing() override { inputs_[0]->StopProducing(this); }
+
+ Future<> finished() override { return finished_; }
+
+ void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+ DCHECK_EQ(input, inputs_[0]);
+ // Accumulate data. NOTE(review): batches_ is appended to without a lock;
+ // confirm that InputReceived is never called concurrently on this node.
+ auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(),
+ plan()->exec_context()->memory_pool());
+ if (ErrorIfNotOk(maybe_batch.status())) return;
+ batches_.push_back(maybe_batch.MoveValueUnsafe());
+
+ if (input_counter_.Increment()) {
+ ErrorIfNotOk(Finish());
+ }
+ }
+
+ void ErrorReceived(ExecNode* input, Status error) override {
+ DCHECK_EQ(input, inputs_[0]);
+ outputs_[0]->ErrorReceived(this, std::move(error));
+ }
+
+ void InputFinished(ExecNode* input, int seq_stop) override {
+ if (input_counter_.SetTotal(seq_stop)) {
+ ErrorIfNotOk(Finish());
+ }
+ }
+
+ private:
+ Status Finish() {
+ ARROW_ASSIGN_OR_RAISE(
+ auto table,
+ Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
+ ARROW_ASSIGN_OR_RAISE(auto indices,
+ SortIndices(table, sort_options_, plan()->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(auto sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
+ plan()->exec_context()));
+
+ TableBatchReader reader(*sorted.table());
+ int64_t count = 0;
+ while (true) {
+ std::shared_ptr batch;
+ RETURN_NOT_OK(reader.ReadNext(&batch));
+ if (!batch) break;
+ ExecBatch exec_batch(*batch);
+ exec_batch.values.emplace_back(count);
+ outputs_[0]->InputReceived(this, static_cast(count), std::move(exec_batch));
+ count++;
+ }
+
+ outputs_[0]->InputFinished(this, static_cast(count));
+ finished_.MarkFinished();
+ return Status::OK();
+ }
+
+ SortOptions sort_options_;
+ std::vector> batches_;
+ AtomicCounter input_counter_;
+ Future<> finished_;
+};
+
+ExecFactoryRegistry::AddOnLoad kRegisterOrderBy("order_by", OrderByNode::Make);
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 3b3d39fd36a..9e46ead21e8 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -26,6 +26,7 @@
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/record_batch.h"
+#include "arrow/table.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
@@ -262,6 +263,7 @@ BatchesWithSchema MakeBasicBatches() {
BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema,
int num_batches = 10, int batch_size = 4) {
BatchesWithSchema out;
+ out.schema = schema;
random::RandomArrayGenerator rng(42);
out.batches.resize(num_batches);
@@ -301,6 +303,37 @@ TEST(ExecPlanExecution, SourceSink) {
}
}
+TEST(ExecPlanExecution, SourceOrderBy) {
+ std::vector expected = {
+ ExecBatchFromJSON({int32(), boolean()},
+ "[[4, false], [5, null], [6, false], [7, false], [null, true]]")};
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator> sink_gen;
+
+ auto basic_data = MakeBasicBatches();
+
+ SortOptions options({SortKey("i32", SortOrder::Ascending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(parallel, slow)}},
+ {"order_by", OrderByNodeOptions(options)},
+ {"reorder", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(::testing::ElementsAreArray(expected))));
+ }
+ }
+}
+
TEST(ExecPlanExecution, SourceSinkError) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator> sink_gen;
@@ -355,6 +388,53 @@ TEST(ExecPlanExecution, StressSourceSink) {
}
}
+TEST(ExecPlanExecution, StressSourceOrderBy) {
+ auto input_schema = schema({field("a", int32()), field("b", boolean())});
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ int num_batches = slow && !parallel ? 30 : 300;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator> sink_gen;
+
+ auto random_data = MakeRandomBatches(input_schema, num_batches);
+
+ SortOptions options({SortKey("a", SortOrder::Ascending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{random_data.schema,
+ random_data.gen(parallel, slow)}},
+ {"order_by", OrderByNodeOptions(options)},
+ {"reorder", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ // Check that data is sorted appropriately
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches,
+ StartAndCollect(plan.get(), sink_gen));
+ RecordBatchVector batches, original_batches;
+ for (const auto& batch : exec_batches) {
+ ASSERT_OK_AND_ASSIGN(auto rb, batch.ToRecordBatch(input_schema));
+ batches.push_back(std::move(rb));
+ }
+ for (const auto& batch : random_data.batches) {
+ ASSERT_OK_AND_ASSIGN(auto rb, batch.ToRecordBatch(input_schema));
+ original_batches.push_back(std::move(rb));
+ }
+ ASSERT_OK_AND_ASSIGN(auto actual, Table::FromRecordBatches(input_schema, batches));
+ ASSERT_OK_AND_ASSIGN(auto original,
+ Table::FromRecordBatches(input_schema, original_batches));
+ ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(original, options));
+ ASSERT_OK_AND_ASSIGN(auto expected, Take(original, sort_indices));
+ AssertTablesEqual(*actual, *expected.table());
+ }
+ }
+}
+
TEST(ExecPlanExecution, StressSourceSinkStopped) {
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");
@@ -541,6 +621,46 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumFilter) {
}
}
+TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ int batch_multiplicity = parallel ? 100 : 1;
+ auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator> sink_gen;
+
+ SortOptions options({SortKey("str", SortOrder::Descending)});
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{input.schema, input.gen(parallel, /*slow=*/false)}},
+ {"filter",
+ FilterNodeOptions{greater_equal(field_ref("i32"), literal(0))}},
+ {"project", ProjectNodeOptions{{
+ field_ref("str"),
+ call("multiply", {field_ref("i32"), literal(2)}),
+ }}},
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"multiply(i32, 2)"},
+ /*names=*/{"sum(multiply(i32, 2))"},
+ /*keys=*/{"str"}}},
+ {"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
+ literal(10 * batch_multiplicity))}},
+ {"order_by", OrderByNodeOptions{options}},
+ {"reorder", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(::testing::ElementsAreArray({ExecBatchFromJSON(
+ {int64(), utf8()}, parallel ? R"([[2000, "beta"], [3600, "alfa"]])"
+ : R"([[20, "beta"], [36, "alfa"]])")}))));
+ }
+}
+
TEST(ExecPlanExecution, SourceScalarAggSink) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator> sink_gen;
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index e4a06e0d224..5388d81df62 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -19,6 +19,7 @@
#include "arrow/compute/exec/exec_plan.h"
#include
+#include
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
@@ -135,8 +136,8 @@ class SinkNode : public ExecNode {
}
}
- private:
- void Finish() {
+ protected:
+ virtual void Finish() {
if (producer_.Close()) {
finished_.MarkFinished();
}
@@ -148,7 +149,72 @@ class SinkNode : public ExecNode {
PushGenerator>::Producer producer_;
};
+// A node that reorders inputs according to a tag. To be paired with OrderByNode.
+struct ReorderNode final : public SinkNode {
+ ReorderNode(ExecPlan* plan, std::vector inputs,
+ AsyncGenerator>* generator)
+ : SinkNode(plan, std::move(inputs), generator) {}
+
+ const char* kind_name() override { return "ReorderNode"; }
+
+ static Result Make(ExecPlan* plan, std::vector inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "ReorderNode"));
+
+ const auto& sink_options = checked_cast(options);
+ return plan->EmplaceNode(plan, std::move(inputs),
+ sink_options.generator);
+ }
+
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+    {
+      std::unique_lock lock(mutex_);
+      const auto& tag_scalar = *batch.values.back().scalar();
+      const int64_t tag = checked_cast(tag_scalar).value;
+      batch.values.pop_back();
+      PushAvailable();
+      if (tag == next_batch_index_) {
+        next_batch_index_++;
+        producer_.Push(std::move(batch));
+      } else {
+        batches_.emplace(tag, std::move(batch));
+      }
+    }  // unlock before Finish(), which takes mutex_ itself
+    // Count the batch only after it is pushed or buffered. Counting first made the
+    // batch that completes the input hit the early return and get dropped.
+    if (input_counter_.Increment()) Finish();
+  }
+
+ protected:
+ void PushAvailable() {
+ decltype(batches_)::iterator it;
+ while ((it = batches_.find(next_batch_index_)) != batches_.end()) {
+ auto batch = std::move(it->second);
+ bool did_push = producer_.Push(std::move(batch));
+ batches_.erase(it);
+ // producer was Closed already
+ if (!did_push) return;
+ next_batch_index_++;
+ }
+ }
+
+ void Finish() override {
+ {
+ std::unique_lock lock(mutex_);
+ PushAvailable();
+ }
+ SinkNode::Finish();
+ }
+
+ private:
+ std::unordered_map batches_;
+ std::mutex mutex_;
+ int64_t next_batch_index_ = 0;
+};
+
ExecFactoryRegistry::AddOnLoad kRegisterSink("sink", SinkNode::Make);
+ExecFactoryRegistry::AddOnLoad kRegisterReorder("reorder", ReorderNode::Make);
} // namespace
} // namespace compute
From 5daae9f815c4dce8b9afb6647af0424a563ffebf Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 6 Aug 2021 17:19:16 -0400
Subject: [PATCH 2/6] ARROW-13540: [C++] Draft a hash_arg_min_max kernel
---
cpp/src/arrow/compute/exec/aggregate_node.cc | 14 +-
cpp/src/arrow/compute/exec/plan_test.cc | 46 ++++
.../arrow/compute/kernels/hash_aggregate.cc | 248 +++++++++++++++++-
3 files changed, 301 insertions(+), 7 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index de9078cd07e..261e69c956d 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -362,9 +362,17 @@ struct GroupByNode : ExecNode {
KernelContext kernel_ctx{ctx_};
kernel_ctx.SetState(state->agg_states[i].get());
- ARROW_ASSIGN_OR_RAISE(
- auto agg_batch,
- ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
+ ExecBatch agg_batch;
+ if (agg_kernels_[i]->signature->in_types().size() == 2) {
+ ARROW_ASSIGN_OR_RAISE(
+ agg_batch, ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
+ } else {
+ // Order-dependent-kernel; assume an upstream OrderByNode has
+ // placed the batch index as the last value
+ ARROW_ASSIGN_OR_RAISE(
+ agg_batch, ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch,
+ batch.values.back()}));
+ }
RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 9e46ead21e8..0b05840a9c2 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -661,6 +661,52 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
}
}
+TEST(ExecPlanExecution, SourceOrderByGroupSink) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ int batch_multiplicity = parallel ? 1000 : 1;
+ auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator> sink_gen;
+
+ SortOptions options({SortKey("str", SortOrder::Ascending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{input.schema,
+ input.gen(parallel, /*slow=*/false)}},
+ {"order_by", OrderByNodeOptions{options}},
+ {"aggregate", AggregateNodeOptions{
+ /*aggregates=*/{{"hash_arg_min_max", nullptr}},
+ /*targets=*/{"i32"},
+ /*names=*/{"arg_min_max(i32)"},
+ /*keys=*/{"str"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(::testing::ElementsAreArray(
+ {ExecBatchFromJSON({struct_({
+ field("min", int64()),
+ field("max", int64()),
+ }),
+ utf8()},
+ parallel ?
+ R"([
+ [{"min": 4, "max": 0}, "alfa"],
+ [{"min": 5001, "max": 5000}, "beta"],
+ [{"min": 7000, "max": 7001}, "gama"]
+])"
+ : R"([
+ [{"min": 4, "max": 0}, "alfa"],
+ [{"min": 6, "max": 5}, "beta"],
+ [{"min": 7, "max": 8}, "gama"]
+])")}))));
+ }
+}
+
TEST(ExecPlanExecution, SourceScalarAggSink) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator> sink_gen;
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index b3d602a89ac..f9bf93116cc 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -824,6 +824,18 @@ Status AddHashAggKernels(
return Status::OK();
}
+HashAggregateKernel MakeOrderDependentKernel(InputType argument_type, KernelInit init) {
+ HashAggregateKernel kernel = MakeKernel(argument_type, init);
+ kernel.signature = KernelSignature::Make(
+ {std::move(argument_type), InputType::Array(Type::UINT32),
+ InputType::Scalar(Type::INT64)},
+ OutputType(
+ [](KernelContext* ctx, const std::vector&) -> Result {
+ return checked_cast(ctx->state())->out_type();
+ }));
+ return kernel;
+}
+
// ----------------------------------------------------------------------
// Count implementation
@@ -1700,6 +1712,207 @@ struct GroupedMinMaxFactory {
InputType argument_type;
};
+// ----------------------------------------------------------------------
+// ArgMinMax implementation
+
+template
+struct GroupedArgMinMaxImpl : public GroupedAggregator {
+ using CType = typename TypeTraits::CType;
+
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ options_ = *checked_cast(options);
+ mins_ = TypedBufferBuilder(ctx->memory_pool());
+ maxes_ = TypedBufferBuilder(ctx->memory_pool());
+ min_offsets_ = TypedBufferBuilder(ctx->memory_pool());
+ max_offsets_ = TypedBufferBuilder(ctx->memory_pool());
+ min_batch_indices_ = TypedBufferBuilder(ctx->memory_pool());
+ max_batch_indices_ = TypedBufferBuilder(ctx->memory_pool());
+ has_values_ = TypedBufferBuilder(ctx->memory_pool());
+ has_nulls_ = TypedBufferBuilder(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(mins_.Append(added_groups, AntiExtrema::anti_min()));
+ RETURN_NOT_OK(maxes_.Append(added_groups, AntiExtrema::anti_max()));
+ RETURN_NOT_OK(min_offsets_.Append(added_groups, -1));
+ RETURN_NOT_OK(max_offsets_.Append(added_groups, -1));
+ RETURN_NOT_OK(min_batch_indices_.Append(added_groups, -1));
+ RETURN_NOT_OK(max_batch_indices_.Append(added_groups, -1));
+ RETURN_NOT_OK(has_values_.Append(added_groups, false));
+ RETURN_NOT_OK(has_nulls_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ DCHECK_EQ(3, batch.num_values());
+ auto g = batch[1].array()->GetValues(1);
+ const Scalar& tag_scalar = *batch.values.back().scalar();
+ const int64_t batch_index = UnboxScalar::Unbox(tag_scalar);
+ auto raw_mins = reinterpret_cast(mins_.mutable_data());
+ auto raw_maxes = reinterpret_cast(maxes_.mutable_data());
+ auto max_offsets = max_offsets_.mutable_data();
+ auto max_batch_indices = max_batch_indices_.mutable_data();
+ auto min_offsets = min_offsets_.mutable_data();
+ auto min_batch_indices = min_batch_indices_.mutable_data();
+ batch_sizes_.emplace(batch_index, batch.length);
+
+ int64_t index = 0;
+ VisitArrayDataInline(
+ *batch[0].array(),
+ [&](CType val) {
+ if (val > raw_maxes[*g] || max_batch_indices[*g] < 0) {
+ raw_maxes[*g] = val;
+ max_offsets[*g] = index;
+ max_batch_indices[*g] = batch_index;
+ }
+ // TODO: test an array that contains the antiextreme
+ if (val < raw_mins[*g] || min_batch_indices[*g] < 0) {
+ raw_mins[*g] = val;
+ min_offsets[*g] = index;
+ min_batch_indices[*g] = batch_index;
+ }
+ BitUtil::SetBit(has_values_.mutable_data(), *g++);
+ index++;
+ },
+ [&] {
+ BitUtil::SetBit(has_nulls_.mutable_data(), *g++);
+ index++;
+ });
+ return Status::OK();
+ }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast(&raw_other);
+    // Fold `other` into this state; group_id_mapping[i] is our id for its group i.
+    batch_sizes_.insert(other->batch_sizes_.begin(), other->batch_sizes_.end());
+
+    // TODO: go back and clean up these casts
+    auto raw_mins = reinterpret_cast(mins_.mutable_data());
+    auto min_offsets = min_offsets_.mutable_data();
+    auto min_batch_indices = min_batch_indices_.mutable_data();
+    auto raw_maxes = reinterpret_cast(maxes_.mutable_data());
+    auto max_offsets = max_offsets_.mutable_data();
+    auto max_batch_indices = max_batch_indices_.mutable_data();
+
+    auto other_raw_mins = reinterpret_cast(other->mins_.data());
+    auto other_min_offsets = other->min_offsets_.mutable_data();
+    auto other_min_batch_indices = other->min_batch_indices_.mutable_data();
+    auto other_raw_maxes = reinterpret_cast(other->maxes_.data());
+    auto other_max_offsets = other->max_offsets_.mutable_data();
+    auto other_max_batch_indices = other->max_batch_indices_.mutable_data();
+    // A batch index of -1 marks an empty slot. Equal values go to the earlier batch.
+    auto g = group_id_mapping.GetValues(1);
+    for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+      const auto other_min_batch = other_min_batch_indices[other_g];
+      if (other_min_batch >= 0 &&
+          (min_batch_indices[*g] < 0 || other_raw_mins[other_g] < raw_mins[*g] ||
+           (other_raw_mins[other_g] == raw_mins[*g] &&
+            other_min_batch < min_batch_indices[*g]))) {
+        raw_mins[*g] = other_raw_mins[other_g];
+        min_offsets[*g] = other_min_offsets[other_g];
+        min_batch_indices[*g] = other_min_batch;
+      }
+      const auto other_max_batch = other_max_batch_indices[other_g];
+      if (other_max_batch >= 0 &&
+          (max_batch_indices[*g] < 0 || other_raw_maxes[other_g] > raw_maxes[*g] ||
+           (other_raw_maxes[other_g] == raw_maxes[*g] &&
+            other_max_batch < max_batch_indices[*g]))) {
+        raw_maxes[*g] = other_raw_maxes[other_g];
+        max_offsets[*g] = other_max_offsets[other_g];
+        max_batch_indices[*g] = other_max_batch;
+      }
+
+      if (BitUtil::GetBit(other->has_values_.data(), other_g)) {
+        BitUtil::SetBit(has_values_.mutable_data(), *g);
+      }
+      if (BitUtil::GetBit(other->has_nulls_.data(), other_g)) {
+        BitUtil::SetBit(has_nulls_.mutable_data(), *g);
+      }
+    }
+    return Status::OK();
+  }
+
+ Result Finalize() override {
+ // aggregation for group is valid if there was at least one value in that group
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_values_.Finish());
+
+ if (!options_.skip_nulls) {
+ // ... and there were no nulls in that group
+ ARROW_ASSIGN_OR_RAISE(auto has_nulls, has_nulls_.Finish());
+ arrow::internal::BitmapAndNot(null_bitmap->data(), 0, has_nulls->data(), 0,
+ num_groups_, 0, null_bitmap->mutable_data());
+ }
+
+ // Compute the actual row index
+ int64_t* min_offsets = min_offsets_.mutable_data();
+ int64_t* max_offsets = max_offsets_.mutable_data();
+ const int64_t* min_batch_indices = min_batch_indices_.mutable_data();
+ const int64_t* max_batch_indices = max_batch_indices_.mutable_data();
+ for (int64_t batch_idx = 0; static_cast(batch_idx) < batch_sizes_.size();
+ batch_idx++) {
+ for (int64_t i = 0; i < num_groups_; i++) {
+ if (batch_idx < min_batch_indices[i]) {
+ min_offsets[i] += batch_sizes_[batch_idx];
+ }
+ if (batch_idx < max_batch_indices[i]) {
+ max_offsets[i] += batch_sizes_[batch_idx];
+ }
+ }
+ }
+
+ auto mins = ArrayData::Make(int64(), num_groups_, {null_bitmap, nullptr});
+ auto maxes = ArrayData::Make(int64(), num_groups_, {std::move(null_bitmap), nullptr});
+ ARROW_ASSIGN_OR_RAISE(mins->buffers[1], min_offsets_.Finish());
+ ARROW_ASSIGN_OR_RAISE(maxes->buffers[1], max_offsets_.Finish());
+
+ return ArrayData::Make(out_type(), num_groups_, {nullptr},
+ {std::move(mins), std::move(maxes)});
+ }
+
+ std::shared_ptr out_type() const override {
+ return struct_({field("min", int64()), field("max", int64())});
+ }
+
+ int64_t num_groups_;
+ TypedBufferBuilder mins_, maxes_;
+ TypedBufferBuilder min_offsets_, min_batch_indices_, max_offsets_,
+ max_batch_indices_;
+ TypedBufferBuilder has_values_, has_nulls_;
+ std::unordered_map batch_sizes_;
+ ScalarAggregateOptions options_;
+};
+
+struct GroupedArgMinMaxFactory {
+ template
+ enable_if_number Visit(const T&) {
+ kernel = MakeOrderDependentKernel(std::move(argument_type),
+ HashAggregateInit>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Computing argmin/argmax of data of type ", type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Computing argmin/argmax of data of type ", type);
+ }
+
+ static Result Make(const std::shared_ptr& type) {
+ GroupedArgMinMaxFactory factory;
+ factory.argument_type = InputType::Array(type);
+ RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+ return std::move(factory.kernel);
+ }
+
+ HashAggregateKernel kernel;
+ InputType argument_type;
+};
+
// ----------------------------------------------------------------------
// Any/All implementation
@@ -1832,10 +2045,19 @@ Result> GetKernels(
for (size_t i = 0; i < aggregates.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(auto function,
ctx->func_registry()->GetFunction(aggregates[i].function));
- ARROW_ASSIGN_OR_RAISE(
- const Kernel* kernel,
- function->DispatchExact({in_descrs[i], ValueDescr::Array(uint32())}));
- kernels[i] = static_cast(kernel);
+ if (function->arity().num_args == 3) {
+ // Order-dependent kernel
+ ARROW_ASSIGN_OR_RAISE(
+ const Kernel* kernel,
+ function->DispatchExact(
+ {in_descrs[i], ValueDescr::Array(uint32()), ValueDescr::Scalar(int64())}));
+ kernels[i] = static_cast(kernel);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ const Kernel* kernel,
+ function->DispatchExact({in_descrs[i], ValueDescr::Array(uint32())}));
+ kernels[i] = static_cast(kernel);
+ }
}
return kernels;
}
@@ -2128,6 +2350,14 @@ const FunctionDoc hash_min_max_doc{
{"array", "group_id_array"},
"ScalarAggregateOptions"};
+const FunctionDoc hash_arg_min_max_doc{
+ "Compute the indices of the minimum and maximum values of a numeric array",
+ ("If there are duplicate values, the least index is taken.\n"
+ "Null values are ignored by default.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array", "group_id_array", "batch_index_tag"},
+ "ScalarAggregateOptions"};
+
const FunctionDoc hash_any_doc{"Test whether any element evaluates to true",
("Null values are ignored."),
{"array", "group_id_array"}};
@@ -2233,6 +2463,16 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+ {
+ static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
+ auto func = std::make_shared(
+ "hash_arg_min_max", Arity::Ternary(), &hash_arg_min_max_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(
+ AddHashAggKernels(NumericTypes(), GroupedArgMinMaxFactory::Make, func.get()));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
{
auto func = std::make_shared("hash_any", Arity::Binary(),
&hash_any_doc);
From c5c8f6f1dfb64f0cd46aafebc06cea96726acc00 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 10 Aug 2021 10:48:32 -0400
Subject: [PATCH 3/6] ARROW-13540: [C++] Fix undefined behavior
---
cpp/src/arrow/compute/exec/order_by_node.cc | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/order_by_node.cc b/cpp/src/arrow/compute/exec/order_by_node.cc
index e0f2445796e..22070218cbc 100644
--- a/cpp/src/arrow/compute/exec/order_by_node.cc
+++ b/cpp/src/arrow/compute/exec/order_by_node.cc
@@ -33,8 +33,9 @@ using arrow::internal::checked_cast;
// Simple in-memory sort node. Accumulates all data, then sorts and
// emits output batches in order.
struct OrderByNode final : public ExecNode {
- OrderByNode(ExecPlan* plan, std::vector inputs, SortOptions sort_options)
- : ExecNode(plan, std::move(inputs), {"target"}, inputs[0]->output_schema(),
+ OrderByNode(ExecPlan* plan, std::vector inputs,
+ std::shared_ptr output_schema, SortOptions sort_options)
+ : ExecNode(plan, std::move(inputs), {"target"}, std::move(output_schema),
/*num_outputs=*/1),
sort_options_(std::move(sort_options)) {}
@@ -49,10 +50,11 @@ struct OrderByNode final : public ExecNode {
fields.reserve((order_by_options.sort_options.sort_keys.size()));
for (const auto& key : order_by_options.sort_options.sort_keys)
fields.push_back(key.name);
- RETURN_NOT_OK(inputs[0]->output_schema()->CanReferenceFieldsByNames(fields));
+ auto output_schema = inputs[0]->output_schema();
+ RETURN_NOT_OK(output_schema->CanReferenceFieldsByNames(fields));
- return plan->EmplaceNode(plan, std::move(inputs),
- order_by_options.sort_options);
+ return plan->EmplaceNode(
+ plan, std::move(inputs), std::move(output_schema), order_by_options.sort_options);
}
Status StartProducing() override {
From 554f094343a7acc4160073dd53975ba3e2da2e8a Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 10 Aug 2021 22:19:48 -0400
Subject: [PATCH 4/6] ARROW-13540: [C++] Convert back to sink node
---
cpp/src/arrow/CMakeLists.txt | 1 -
cpp/src/arrow/compute/exec/aggregate_node.cc | 14 +-
cpp/src/arrow/compute/exec/options.h | 28 +-
cpp/src/arrow/compute/exec/order_by_node.cc | 139 ----------
cpp/src/arrow/compute/exec/plan_test.cc | 55 +---
cpp/src/arrow/compute/exec/sink_node.cc | 95 ++++---
.../arrow/compute/kernels/hash_aggregate.cc | 248 +-----------------
7 files changed, 82 insertions(+), 498 deletions(-)
delete mode 100644 cpp/src/arrow/compute/exec/order_by_node.cc
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 90fb2fae093..308ee49972c 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -374,7 +374,6 @@ if(ARROW_COMPUTE)
compute/exec/exec_plan.cc
compute/exec/expression.cc
compute/exec/filter_node.cc
- compute/exec/order_by_node.cc
compute/exec/project_node.cc
compute/exec/source_node.cc
compute/exec/sink_node.cc
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index 261e69c956d..de9078cd07e 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -362,17 +362,9 @@ struct GroupByNode : ExecNode {
KernelContext kernel_ctx{ctx_};
kernel_ctx.SetState(state->agg_states[i].get());
- ExecBatch agg_batch;
- if (agg_kernels_[i]->signature->in_types().size() == 2) {
- ARROW_ASSIGN_OR_RAISE(
- agg_batch, ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
- } else {
- // Order-dependent-kernel; assume an upstream OrderByNode has
- // placed the batch index as the last value
- ARROW_ASSIGN_OR_RAISE(
- agg_batch, ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch,
- batch.values.back()}));
- }
+ ARROW_ASSIGN_OR_RAISE(
+ auto agg_batch,
+ ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 7ea21d707ae..5732b78aada 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -64,19 +64,6 @@ class ARROW_EXPORT FilterNodeOptions : public ExecNodeOptions {
Expression filter_expression;
};
-/// \brief Make a node which sorts rows passed through it
-///
-/// All batches pushed to this node will be accumulated, then sorted, by the given
-/// fields. Then sorted batches will be pushed to the next node, along a tag
-/// indicating the absolute order of the batches.
-class ARROW_EXPORT OrderByNodeOptions : public ExecNodeOptions {
- public:
- explicit OrderByNodeOptions(SortOptions sort_options)
- : sort_options(std::move(sort_options)) {}
-
- SortOptions sort_options;
-};
-
/// \brief Make a node which executes expressions on input batches, producing new batches.
///
/// Each expression will be evaluated against each batch which is pushed to
@@ -125,5 +112,20 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
std::function>()>* generator;
};
+/// \brief Make a node which sorts rows passed through it
+///
+/// All batches pushed to this node will be accumulated, then sorted, by the given
+/// fields. Then sorted batches will be pushed to the next node, along a tag
+/// indicating the absolute order of the batches.
+class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions {
+ public:
+ explicit OrderBySinkNodeOptions(
+ SortOptions sort_options,
+ std::function>()>* generator)
+ : SinkNodeOptions(generator), sort_options(std::move(sort_options)) {}
+
+ SortOptions sort_options;
+};
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/order_by_node.cc b/cpp/src/arrow/compute/exec/order_by_node.cc
deleted file mode 100644
index 22070218cbc..00000000000
--- a/cpp/src/arrow/compute/exec/order_by_node.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include "arrow/compute/exec/exec_plan.h"
-
-#include "arrow/compute/api_vector.h"
-#include "arrow/compute/exec/options.h"
-#include "arrow/compute/exec/util.h"
-#include "arrow/table.h"
-#include "arrow/util/future.h"
-#include "arrow/util/logging.h"
-
-namespace arrow {
-
-namespace compute {
-
-using arrow::internal::checked_cast;
-
-// Simple in-memory sort node. Accumulates all data, then sorts and
-// emits output batches in order.
-struct OrderByNode final : public ExecNode {
- OrderByNode(ExecPlan* plan, std::vector inputs,
- std::shared_ptr output_schema, SortOptions sort_options)
- : ExecNode(plan, std::move(inputs), {"target"}, std::move(output_schema),
- /*num_outputs=*/1),
- sort_options_(std::move(sort_options)) {}
-
- const char* kind_name() override { return "OrderByNode"; }
-
- static Result Make(ExecPlan* plan, std::vector inputs,
- const ExecNodeOptions& options) {
- RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderByNode"));
-
- const auto& order_by_options = checked_cast(options);
- std::vector fields;
- fields.reserve((order_by_options.sort_options.sort_keys.size()));
- for (const auto& key : order_by_options.sort_options.sort_keys)
- fields.push_back(key.name);
- auto output_schema = inputs[0]->output_schema();
- RETURN_NOT_OK(output_schema->CanReferenceFieldsByNames(fields));
-
- return plan->EmplaceNode(
- plan, std::move(inputs), std::move(output_schema), order_by_options.sort_options);
- }
-
- Status StartProducing() override {
- finished_ = Future<>::Make();
- return Status::OK();
- }
-
- void PauseProducing(ExecNode* output) override {}
-
- void ResumeProducing(ExecNode* output) override {}
-
- void StopProducing(ExecNode* output) override {
- DCHECK_EQ(output, outputs_[0]);
- StopProducing();
- }
-
- void StopProducing() override { inputs_[0]->StopProducing(this); }
-
- Future<> finished() override { return finished_; }
-
- void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
- DCHECK_EQ(input, inputs_[0]);
-
- // Accumulate data
- auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(),
- plan()->exec_context()->memory_pool());
- if (ErrorIfNotOk(maybe_batch.status())) return;
- batches_.push_back(maybe_batch.MoveValueUnsafe());
-
- if (input_counter_.Increment()) {
- ErrorIfNotOk(Finish());
- }
- }
-
- void ErrorReceived(ExecNode* input, Status error) override {
- DCHECK_EQ(input, inputs_[0]);
- outputs_[0]->ErrorReceived(this, std::move(error));
- }
-
- void InputFinished(ExecNode* input, int seq_stop) override {
- if (input_counter_.SetTotal(seq_stop)) {
- ErrorIfNotOk(Finish());
- }
- }
-
- private:
- Status Finish() {
- ARROW_ASSIGN_OR_RAISE(
- auto table,
- Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
- ARROW_ASSIGN_OR_RAISE(auto indices,
- SortIndices(table, sort_options_, plan()->exec_context()));
- ARROW_ASSIGN_OR_RAISE(auto sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
- plan()->exec_context()));
-
- TableBatchReader reader(*sorted.table());
- int64_t count = 0;
- while (true) {
- std::shared_ptr batch;
- RETURN_NOT_OK(reader.ReadNext(&batch));
- if (!batch) break;
- ExecBatch exec_batch(*batch);
- exec_batch.values.emplace_back(count);
- outputs_[0]->InputReceived(this, static_cast(count), std::move(exec_batch));
- count++;
- }
-
- outputs_[0]->InputFinished(this, static_cast(count));
- finished_.MarkFinished();
- return Status::OK();
- }
-
- SortOptions sort_options_;
- std::vector> batches_;
- AtomicCounter input_counter_;
- Future<> finished_;
-};
-
-ExecFactoryRegistry::AddOnLoad kRegisterOrderBy("order_by", OrderByNode::Make);
-
-} // namespace compute
-} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 0b05840a9c2..d4478e08bed 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -323,8 +323,7 @@ TEST(ExecPlanExecution, SourceOrderBy) {
{
{"source", SourceNodeOptions{basic_data.schema,
basic_data.gen(parallel, slow)}},
- {"order_by", OrderByNodeOptions(options)},
- {"reorder", SinkNodeOptions{&sink_gen}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
})
.AddToPlan(plan.get()));
@@ -408,8 +407,7 @@ TEST(ExecPlanExecution, StressSourceOrderBy) {
{
{"source", SourceNodeOptions{random_data.schema,
random_data.gen(parallel, slow)}},
- {"order_by", OrderByNodeOptions(options)},
- {"reorder", SinkNodeOptions{&sink_gen}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
})
.AddToPlan(plan.get()));
@@ -649,8 +647,7 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
/*keys=*/{"str"}}},
{"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
literal(10 * batch_multiplicity))}},
- {"order_by", OrderByNodeOptions{options}},
- {"reorder", SinkNodeOptions{&sink_gen}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
})
.AddToPlan(plan.get()));
@@ -661,52 +658,6 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
}
}
-TEST(ExecPlanExecution, SourceOrderByGroupSink) {
- for (bool parallel : {false, true}) {
- SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
-
- int batch_multiplicity = parallel ? 1000 : 1;
- auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
-
- ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
- AsyncGenerator> sink_gen;
-
- SortOptions options({SortKey("str", SortOrder::Ascending)});
- ASSERT_OK(Declaration::Sequence(
- {
- {"source", SourceNodeOptions{input.schema,
- input.gen(parallel, /*slow=*/false)}},
- {"order_by", OrderByNodeOptions{options}},
- {"aggregate", AggregateNodeOptions{
- /*aggregates=*/{{"hash_arg_min_max", nullptr}},
- /*targets=*/{"i32"},
- /*names=*/{"arg_min_max(i32)"},
- /*keys=*/{"str"}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
-
- ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
- Finishes(ResultWith(::testing::ElementsAreArray(
- {ExecBatchFromJSON({struct_({
- field("min", int64()),
- field("max", int64()),
- }),
- utf8()},
- parallel ?
- R"([
- [{"min": 4, "max": 0}, "alfa"],
- [{"min": 5001, "max": 5000}, "beta"],
- [{"min": 7000, "max": 7001}, "gama"]
-])"
- : R"([
- [{"min": 4, "max": 0}, "alfa"],
- [{"min": 6, "max": 5}, "beta"],
- [{"min": 7, "max": 8}, "gama"]
-])")}))));
- }
-}
-
TEST(ExecPlanExecution, SourceScalarAggSink) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator> sink_gen;
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 5388d81df62..b5891c6a446 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -21,6 +21,7 @@
#include
#include
+#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
@@ -28,6 +29,7 @@
#include "arrow/compute/exec_internal.h"
#include "arrow/datum.h"
#include "arrow/result.h"
+#include "arrow/table.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
@@ -149,72 +151,89 @@ class SinkNode : public ExecNode {
PushGenerator>::Producer producer_;
};
-// A node that reorders inputs according to a tag. To be paired with OrderByNode.
-struct ReorderNode final : public SinkNode {
- ReorderNode(ExecPlan* plan, std::vector inputs,
- AsyncGenerator>* generator)
- : SinkNode(plan, std::move(inputs), generator) {}
+// A sink node that accumulates inputs, then sorts them before emitting them.
+struct OrderBySinkNode final : public SinkNode {
+ OrderBySinkNode(ExecPlan* plan, std::vector inputs, SortOptions sort_options,
+ AsyncGenerator>* generator)
+ : SinkNode(plan, std::move(inputs), generator),
+ sort_options_(std::move(sort_options)) {}
- const char* kind_name() override { return "ReorderNode"; }
+ const char* kind_name() override { return "OrderBySinkNode"; }
static Result Make(ExecPlan* plan, std::vector inputs,
const ExecNodeOptions& options) {
- RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "ReorderNode"));
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode"));
- const auto& sink_options = checked_cast(options);
- return plan->EmplaceNode(plan, std::move(inputs),
- sink_options.generator);
+ const auto& sink_options = checked_cast(options);
+ return plan->EmplaceNode(
+ plan, std::move(inputs), sink_options.sort_options, sink_options.generator);
}
void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
DCHECK_EQ(input, inputs_[0]);
+ // Accumulate data
+ {
+ std::unique_lock lock(mutex_);
+ auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(),
+ plan()->exec_context()->memory_pool());
+ if (ErrorIfNotOk(maybe_batch.status())) return;
+ batches_.push_back(maybe_batch.MoveValueUnsafe());
+ }
+
if (input_counter_.Increment()) {
Finish();
- return;
- }
- std::unique_lock lock(mutex_);
- const auto& tag_scalar = *batch.values.back().scalar();
- const int64_t tag = checked_cast(tag_scalar).value;
- batch.values.pop_back();
- PushAvailable();
- if (tag == next_batch_index_) {
- next_batch_index_++;
- producer_.Push(std::move(batch));
- } else {
- batches_.emplace(tag, std::move(batch));
}
}
protected:
- void PushAvailable() {
- decltype(batches_)::iterator it;
- while ((it = batches_.find(next_batch_index_)) != batches_.end()) {
- auto batch = std::move(it->second);
- bool did_push = producer_.Push(std::move(batch));
- batches_.erase(it);
- // producer was Closed already
- if (!did_push) return;
- next_batch_index_++;
- }
+ Result> SortData() {
+ std::unique_lock lock(mutex_);
+ ARROW_ASSIGN_OR_RAISE(
+ auto table,
+ Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
+ ARROW_ASSIGN_OR_RAISE(auto indices,
+ SortIndices(table, sort_options_, plan()->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(auto sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
+ plan()->exec_context()));
+ return sorted.table();
}
void Finish() override {
- {
- std::unique_lock lock(mutex_);
- PushAvailable();
+ auto maybe_sorted = SortData();
+ if (ErrorIfNotOk(maybe_sorted.status())) {
+ producer_.Push(maybe_sorted.status());
+ SinkNode::Finish();
+ return;
+ }
+ auto sorted = maybe_sorted.MoveValueUnsafe();
+
+ TableBatchReader reader(*sorted);
+ while (true) {
+ std::shared_ptr batch;
+ auto status = reader.ReadNext(&batch);
+ if (!status.ok()) {
+ producer_.Push(std::move(status));
+ SinkNode::Finish();
+ return;
+ }
+ if (!batch) break;
+ bool did_push = producer_.Push(ExecBatch(*batch));
+ if (!did_push) break; // producer_ was Closed already
}
+
SinkNode::Finish();
}
private:
- std::unordered_map batches_;
+ SortOptions sort_options_;
std::mutex mutex_;
- int64_t next_batch_index_ = 0;
+ std::vector> batches_;
};
ExecFactoryRegistry::AddOnLoad kRegisterSink("sink", SinkNode::Make);
-ExecFactoryRegistry::AddOnLoad kRegisterReorder("reorder", ReorderNode::Make);
+ExecFactoryRegistry::AddOnLoad kRegisterOrderBySink("order_by_sink",
+ OrderBySinkNode::Make);
} // namespace
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index f9bf93116cc..b3d602a89ac 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -824,18 +824,6 @@ Status AddHashAggKernels(
return Status::OK();
}
-HashAggregateKernel MakeOrderDependentKernel(InputType argument_type, KernelInit init) {
- HashAggregateKernel kernel = MakeKernel(argument_type, init);
- kernel.signature = KernelSignature::Make(
- {std::move(argument_type), InputType::Array(Type::UINT32),
- InputType::Scalar(Type::INT64)},
- OutputType(
- [](KernelContext* ctx, const std::vector&) -> Result {
- return checked_cast(ctx->state())->out_type();
- }));
- return kernel;
-}
-
// ----------------------------------------------------------------------
// Count implementation
@@ -1712,207 +1700,6 @@ struct GroupedMinMaxFactory {
InputType argument_type;
};
-// ----------------------------------------------------------------------
-// ArgMinMax implementation
-
-template
-struct GroupedArgMinMaxImpl : public GroupedAggregator {
- using CType = typename TypeTraits::CType;
-
- Status Init(ExecContext* ctx, const FunctionOptions* options) override {
- options_ = *checked_cast(options);
- mins_ = TypedBufferBuilder(ctx->memory_pool());
- maxes_ = TypedBufferBuilder(ctx->memory_pool());
- min_offsets_ = TypedBufferBuilder(ctx->memory_pool());
- max_offsets_ = TypedBufferBuilder(ctx->memory_pool());
- min_batch_indices_ = TypedBufferBuilder(ctx->memory_pool());
- max_batch_indices_ = TypedBufferBuilder(ctx->memory_pool());
- has_values_ = TypedBufferBuilder(ctx->memory_pool());
- has_nulls_ = TypedBufferBuilder(ctx->memory_pool());
- return Status::OK();
- }
-
- Status Resize(int64_t new_num_groups) override {
- auto added_groups = new_num_groups - num_groups_;
- num_groups_ = new_num_groups;
- RETURN_NOT_OK(mins_.Append(added_groups, AntiExtrema::anti_min()));
- RETURN_NOT_OK(maxes_.Append(added_groups, AntiExtrema::anti_max()));
- RETURN_NOT_OK(min_offsets_.Append(added_groups, -1));
- RETURN_NOT_OK(max_offsets_.Append(added_groups, -1));
- RETURN_NOT_OK(min_batch_indices_.Append(added_groups, -1));
- RETURN_NOT_OK(max_batch_indices_.Append(added_groups, -1));
- RETURN_NOT_OK(has_values_.Append(added_groups, false));
- RETURN_NOT_OK(has_nulls_.Append(added_groups, false));
- return Status::OK();
- }
-
- Status Consume(const ExecBatch& batch) override {
- DCHECK_EQ(3, batch.num_values());
- auto g = batch[1].array()->GetValues(1);
- const Scalar& tag_scalar = *batch.values.back().scalar();
- const int64_t batch_index = UnboxScalar::Unbox(tag_scalar);
- auto raw_mins = reinterpret_cast(mins_.mutable_data());
- auto raw_maxes = reinterpret_cast(maxes_.mutable_data());
- auto max_offsets = max_offsets_.mutable_data();
- auto max_batch_indices = max_batch_indices_.mutable_data();
- auto min_offsets = min_offsets_.mutable_data();
- auto min_batch_indices = min_batch_indices_.mutable_data();
- batch_sizes_.emplace(batch_index, batch.length);
-
- int64_t index = 0;
- VisitArrayDataInline(
- *batch[0].array(),
- [&](CType val) {
- if (val > raw_maxes[*g] || max_batch_indices[*g] < 0) {
- raw_maxes[*g] = val;
- max_offsets[*g] = index;
- max_batch_indices[*g] = batch_index;
- }
- // TODO: test an array that contains the antiextreme
- if (val < raw_mins[*g] || min_batch_indices[*g] < 0) {
- raw_mins[*g] = val;
- min_offsets[*g] = index;
- min_batch_indices[*g] = batch_index;
- }
- BitUtil::SetBit(has_values_.mutable_data(), *g++);
- index++;
- },
- [&] {
- BitUtil::SetBit(has_nulls_.mutable_data(), *g++);
- index++;
- });
- return Status::OK();
- }
-
- Status Merge(GroupedAggregator&& raw_other,
- const ArrayData& group_id_mapping) override {
- auto other = checked_cast(&raw_other);
-
- batch_sizes_.insert(other->batch_sizes_.begin(), other->batch_sizes_.end());
-
- // TODO: go back and clean up these casts
- auto raw_mins = reinterpret_cast(mins_.mutable_data());
- auto min_offsets = min_offsets_.mutable_data();
- auto min_batch_indices = max_batch_indices_.mutable_data();
- auto raw_maxes = reinterpret_cast(maxes_.mutable_data());
- auto max_offsets = max_offsets_.mutable_data();
- auto max_batch_indices = max_batch_indices_.mutable_data();
-
- auto other_raw_mins = reinterpret_cast(other->mins_.data());
- auto other_min_offsets = other->min_offsets_.mutable_data();
- auto other_min_batch_indices = other->max_batch_indices_.mutable_data();
- auto other_raw_maxes = reinterpret_cast(other->maxes_.data());
- auto other_max_offsets = other->max_offsets_.mutable_data();
- auto other_max_batch_indices = other->max_batch_indices_.mutable_data();
-
- auto g = group_id_mapping.GetValues(1);
- for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
- if (other_raw_mins[other_g] < raw_mins[*g]) {
- raw_mins[*g] = other_raw_mins[other_g];
- min_offsets[*g] = other_min_offsets[other_g];
- min_batch_indices[*g] = other_min_batch_indices[other_g];
- } else if (other_raw_mins[other_g] == raw_mins[*g] &&
- other_min_batch_indices[other_g] < min_batch_indices[*g]) {
- min_offsets[*g] = other_min_offsets[other_g];
- min_batch_indices[*g] = other_min_batch_indices[other_g];
- }
- if (other_raw_maxes[other_g] > raw_maxes[*g]) {
- raw_maxes[*g] = other_raw_maxes[other_g];
- max_offsets[*g] = other_max_offsets[other_g];
- max_batch_indices[*g] = other_max_batch_indices[other_g];
- } else if (other_raw_maxes[other_g] == raw_maxes[*g] &&
- other_max_batch_indices[other_g] < max_batch_indices[*g]) {
- max_offsets[*g] = other_max_offsets[other_g];
- max_batch_indices[*g] = other_max_batch_indices[other_g];
- }
-
- if (BitUtil::GetBit(other->has_values_.data(), other_g)) {
- BitUtil::SetBit(has_values_.mutable_data(), *g);
- }
- if (BitUtil::GetBit(other->has_nulls_.data(), other_g)) {
- BitUtil::SetBit(has_nulls_.mutable_data(), *g);
- }
- }
- return Status::OK();
- }
-
- Result Finalize() override {
- // aggregation for group is valid if there was at least one value in that group
- ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_values_.Finish());
-
- if (!options_.skip_nulls) {
- // ... and there were no nulls in that group
- ARROW_ASSIGN_OR_RAISE(auto has_nulls, has_nulls_.Finish());
- arrow::internal::BitmapAndNot(null_bitmap->data(), 0, has_nulls->data(), 0,
- num_groups_, 0, null_bitmap->mutable_data());
- }
-
- // Compute the actual row index
- int64_t* min_offsets = min_offsets_.mutable_data();
- int64_t* max_offsets = max_offsets_.mutable_data();
- const int64_t* min_batch_indices = min_batch_indices_.mutable_data();
- const int64_t* max_batch_indices = max_batch_indices_.mutable_data();
- for (int64_t batch_idx = 0; static_cast(batch_idx) < batch_sizes_.size();
- batch_idx++) {
- for (int64_t i = 0; i < num_groups_; i++) {
- if (batch_idx < min_batch_indices[i]) {
- min_offsets[i] += batch_sizes_[batch_idx];
- }
- if (batch_idx < max_batch_indices[i]) {
- max_offsets[i] += batch_sizes_[batch_idx];
- }
- }
- }
-
- auto mins = ArrayData::Make(int64(), num_groups_, {null_bitmap, nullptr});
- auto maxes = ArrayData::Make(int64(), num_groups_, {std::move(null_bitmap), nullptr});
- ARROW_ASSIGN_OR_RAISE(mins->buffers[1], min_offsets_.Finish());
- ARROW_ASSIGN_OR_RAISE(maxes->buffers[1], max_offsets_.Finish());
-
- return ArrayData::Make(out_type(), num_groups_, {nullptr},
- {std::move(mins), std::move(maxes)});
- }
-
- std::shared_ptr out_type() const override {
- return struct_({field("min", int64()), field("max", int64())});
- }
-
- int64_t num_groups_;
- TypedBufferBuilder mins_, maxes_;
- TypedBufferBuilder min_offsets_, min_batch_indices_, max_offsets_,
- max_batch_indices_;
- TypedBufferBuilder has_values_, has_nulls_;
- std::unordered_map batch_sizes_;
- ScalarAggregateOptions options_;
-};
-
-struct GroupedArgMinMaxFactory {
- template
- enable_if_number Visit(const T&) {
- kernel = MakeOrderDependentKernel(std::move(argument_type),
- HashAggregateInit>);
- return Status::OK();
- }
-
- Status Visit(const HalfFloatType& type) {
- return Status::NotImplemented("Computing argmin/argmax of data of type ", type);
- }
-
- Status Visit(const DataType& type) {
- return Status::NotImplemented("Computing argmin/argmax of data of type ", type);
- }
-
- static Result Make(const std::shared_ptr& type) {
- GroupedArgMinMaxFactory factory;
- factory.argument_type = InputType::Array(type);
- RETURN_NOT_OK(VisitTypeInline(*type, &factory));
- return std::move(factory.kernel);
- }
-
- HashAggregateKernel kernel;
- InputType argument_type;
-};
-
// ----------------------------------------------------------------------
// Any/All implementation
@@ -2045,19 +1832,10 @@ Result> GetKernels(
for (size_t i = 0; i < aggregates.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(auto function,
ctx->func_registry()->GetFunction(aggregates[i].function));
- if (function->arity().num_args == 3) {
- // Order-dependent kernel
- ARROW_ASSIGN_OR_RAISE(
- const Kernel* kernel,
- function->DispatchExact(
- {in_descrs[i], ValueDescr::Array(uint32()), ValueDescr::Scalar(int64())}));
- kernels[i] = static_cast(kernel);
- } else {
- ARROW_ASSIGN_OR_RAISE(
- const Kernel* kernel,
- function->DispatchExact({in_descrs[i], ValueDescr::Array(uint32())}));
- kernels[i] = static_cast(kernel);
- }
+ ARROW_ASSIGN_OR_RAISE(
+ const Kernel* kernel,
+ function->DispatchExact({in_descrs[i], ValueDescr::Array(uint32())}));
+ kernels[i] = static_cast(kernel);
}
return kernels;
}
@@ -2350,14 +2128,6 @@ const FunctionDoc hash_min_max_doc{
{"array", "group_id_array"},
"ScalarAggregateOptions"};
-const FunctionDoc hash_arg_min_max_doc{
- "Compute the indices of the minimum and maximum values of a numeric array",
- ("If there are duplicate values, the least index is taken.\n"
- "Null values are ignored by default.\n"
- "This can be changed through ScalarAggregateOptions."),
- {"array", "group_id_array", "batch_index_tag"},
- "ScalarAggregateOptions"};
-
const FunctionDoc hash_any_doc{"Test whether any element evaluates to true",
("Null values are ignored."),
{"array", "group_id_array"}};
@@ -2463,16 +2233,6 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}
- {
- static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
- auto func = std::make_shared(
- "hash_arg_min_max", Arity::Ternary(), &hash_arg_min_max_doc,
- &default_scalar_aggregate_options);
- DCHECK_OK(
- AddHashAggKernels(NumericTypes(), GroupedArgMinMaxFactory::Make, func.get()));
- DCHECK_OK(registry->AddFunction(std::move(func)));
- }
-
{
auto func = std::make_shared("hash_any", Arity::Binary(),
&hash_any_doc);
From 7d19509da42c8e134755b28da77a58f331f7d752 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 11 Aug 2021 12:30:18 -0400
Subject: [PATCH 5/6] ARROW-13540: [C++] Refactor and clean up sink node
---
cpp/src/arrow/compute/exec/options.h | 3 +-
cpp/src/arrow/compute/exec/plan_test.cc | 17 +++------
cpp/src/arrow/compute/exec/sink_node.cc | 47 +++++++++++--------------
cpp/src/arrow/compute/exec/util.cc | 11 ++++++
cpp/src/arrow/compute/exec/util.h | 4 +++
5 files changed, 41 insertions(+), 41 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 5732b78aada..acc79bdfdde 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -115,8 +115,7 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
/// \brief Make a node which sorts rows passed through it
///
/// All batches pushed to this node will be accumulated, then sorted, by the given
-/// fields. Then sorted batches will be pushed to the next node, along a tag
-/// indicating the absolute order of the batches.
+/// fields. The sorted batches are then forwarded to the generator in order.
class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions {
public:
explicit OrderBySinkNodeOptions(
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index d4478e08bed..1b41715ddfd 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -25,6 +25,7 @@
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
+#include "arrow/compute/exec/util.h"
#include "arrow/record_batch.h"
#include "arrow/table.h"
#include "arrow/testing/future_util.h"
@@ -37,6 +38,7 @@
#include "arrow/util/vector.h"
using testing::ElementsAre;
+using testing::ElementsAreArray;
using testing::HasSubstr;
using testing::Optional;
using testing::UnorderedElementsAreArray;
@@ -328,7 +330,7 @@ TEST(ExecPlanExecution, SourceOrderBy) {
.AddToPlan(plan.get()));
ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
- Finishes(ResultWith(::testing::ElementsAreArray(expected))));
+ Finishes(ResultWith(ElementsAreArray(expected))));
}
}
}
@@ -414,18 +416,9 @@ TEST(ExecPlanExecution, StressSourceOrderBy) {
// Check that data is sorted appropriately
ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches,
StartAndCollect(plan.get(), sink_gen));
- RecordBatchVector batches, original_batches;
- for (const auto& batch : exec_batches) {
- ASSERT_OK_AND_ASSIGN(auto rb, batch.ToRecordBatch(input_schema));
- batches.push_back(std::move(rb));
- }
- for (const auto& batch : random_data.batches) {
- ASSERT_OK_AND_ASSIGN(auto rb, batch.ToRecordBatch(input_schema));
- original_batches.push_back(std::move(rb));
- }
- ASSERT_OK_AND_ASSIGN(auto actual, Table::FromRecordBatches(input_schema, batches));
+ ASSERT_OK_AND_ASSIGN(auto actual, TableFromExecBatches(input_schema, exec_batches));
ASSERT_OK_AND_ASSIGN(auto original,
- Table::FromRecordBatches(input_schema, original_batches));
+ TableFromExecBatches(input_schema, random_data.batches));
ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(original, options));
ASSERT_OK_AND_ASSIGN(auto expected, Take(original, sort_indices));
AssertTablesEqual(*actual, *expected.table());
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index b5891c6a446..4b36687c2e0 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -187,41 +187,34 @@ struct OrderBySinkNode final : public SinkNode {
}
protected:
- Result> SortData() {
- std::unique_lock lock(mutex_);
- ARROW_ASSIGN_OR_RAISE(
- auto table,
- Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
- ARROW_ASSIGN_OR_RAISE(auto indices,
- SortIndices(table, sort_options_, plan()->exec_context()));
- ARROW_ASSIGN_OR_RAISE(auto sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
- plan()->exec_context()));
- return sorted.table();
- }
-
- void Finish() override {
- auto maybe_sorted = SortData();
- if (ErrorIfNotOk(maybe_sorted.status())) {
- producer_.Push(maybe_sorted.status());
- SinkNode::Finish();
- return;
+ Status DoFinish() {
+ Datum sorted;
+ {
+ std::unique_lock lock(mutex_);
+ ARROW_ASSIGN_OR_RAISE(
+ auto table,
+ Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
+ ARROW_ASSIGN_OR_RAISE(auto indices,
+ SortIndices(table, sort_options_, plan()->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
+ plan()->exec_context()));
}
- auto sorted = maybe_sorted.MoveValueUnsafe();
-
- TableBatchReader reader(*sorted);
+ TableBatchReader reader(*sorted.table());
while (true) {
std::shared_ptr batch;
- auto status = reader.ReadNext(&batch);
- if (!status.ok()) {
- producer_.Push(std::move(status));
- SinkNode::Finish();
- return;
- }
+ RETURN_NOT_OK(reader.ReadNext(&batch));
if (!batch) break;
bool did_push = producer_.Push(ExecBatch(*batch));
if (!did_push) break; // producer_ was Closed already
}
+ return Status::OK();
+ }
+ void Finish() override {
+ Status st = DoFinish();
+ if (ErrorIfNotOk(st)) {
+ producer_.Push(std::move(st));
+ }
SinkNode::Finish();
}
diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc
index eecc617c9c0..aad6dc3d587 100644
--- a/cpp/src/arrow/compute/exec/util.cc
+++ b/cpp/src/arrow/compute/exec/util.cc
@@ -18,6 +18,7 @@
#include "arrow/compute/exec/util.h"
#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/table.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/ubsan.h"
@@ -296,5 +297,15 @@ Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inpu
return Status::OK();
}
+/// Build a single Table with the given schema from a vector of ExecBatches.
+///
+/// Each ExecBatch is converted with ExecBatch::ToRecordBatch(schema); the first
+/// conversion that fails is returned as the error. The resulting record batches
+/// are then assembled with Table::FromRecordBatches.
+Result<std::shared_ptr<Table>> TableFromExecBatches(
+    const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches) {
+  RecordBatchVector batches;
+  // The final size is known up front: allocate once instead of growing per append.
+  batches.reserve(exec_batches.size());
+  for (const auto& batch : exec_batches) {
+    ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema));
+    batches.push_back(std::move(rb));
+  }
+  return Table::FromRecordBatches(schema, batches);
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h
index b7cf0aeaa5e..8bd6a3c5d62 100644
--- a/cpp/src/arrow/compute/exec/util.h
+++ b/cpp/src/arrow/compute/exec/util.h
@@ -188,6 +188,10 @@ ARROW_EXPORT
Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inputs,
int expected_num_inputs, const char* kind_name);
+/// \brief Build a Table with the given schema from a vector of ExecBatches
+ARROW_EXPORT
+Result<std::shared_ptr<Table>> TableFromExecBatches(
+    const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches);
+
class AtomicCounter {
public:
AtomicCounter() = default;
From ceea1d582227926c30dcbb57ebca6b864efbf2b0 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 11 Aug 2021 13:19:49 -0400
Subject: [PATCH 6/6] ARROW-13540: [C++] Fix a few missed things
---
cpp/src/arrow/compute/exec/plan_test.cc | 2 +-
cpp/src/arrow/compute/exec/sink_node.cc | 1 -
2 files changed, 1 insertion(+), 2 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 1b41715ddfd..f4d81ace040 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -645,7 +645,7 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
.AddToPlan(plan.get()));
ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
- Finishes(ResultWith(::testing::ElementsAreArray({ExecBatchFromJSON(
+ Finishes(ResultWith(ElementsAreArray({ExecBatchFromJSON(
{int64(), utf8()}, parallel ? R"([[2000, "beta"], [3600, "alfa"]])"
: R"([[20, "beta"], [36, "alfa"]])")}))));
}
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 4b36687c2e0..4d9f82e582b 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -19,7 +19,6 @@
#include "arrow/compute/exec/exec_plan.h"
#include
-#include
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"