cpp/src/arrow/CMakeLists.txt (2 additions, 0 deletions)
@@ -220,6 +220,7 @@ set(ARROW_SRCS
util/tdigest.cc
util/thread_pool.cc
util/time.cc
util/tracing.cc
util/trie.cc
util/unreachable.cc
util/uri.cc
@@ -389,6 +390,7 @@ if(ARROW_COMPUTE)
compute/exec/key_encode.cc
compute/exec/key_hash.cc
compute/exec/key_map.cc
compute/exec/options.cc
compute/exec/order_by_impl.cc
compute/exec/project_node.cc
compute/exec/sink_node.cc
cpp/src/arrow/compute/exec/aggregate_node.cc (84 additions, 7 deletions)
@@ -32,6 +32,7 @@
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/tracing_internal.h"

namespace arrow {

@@ -165,7 +166,18 @@ class ScalarAggregateNode : public ExecNode {
const char* kind_name() const override { return "ScalarAggregateNode"; }

Status DoConsume(const ExecBatch& batch, size_t thread_index) {
util::tracing::Span span;
START_SPAN(span, "Consume",
{{"aggregate", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});
for (size_t i = 0; i < kernels_.size(); ++i) {
util::tracing::Span span;
START_SPAN(span, aggs_[i].function,
{{"function.name", aggs_[i].function},
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Consume"}});
KernelContext batch_ctx{plan()->exec_context()};
batch_ctx.SetState(states_[i][thread_index].get());

@@ -176,6 +188,12 @@ class ScalarAggregateNode : public ExecNode {
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
util::tracing::Span span;
START_SPAN_WITH_PARENT(span, span_, "InputReceived",
{{"aggregate", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});
DCHECK_EQ(input, inputs_[0]);

auto thread_index = get_thread_index_();
@@ -188,35 +206,42 @@ class ScalarAggregateNode : public ExecNode {
}

void ErrorReceived(ExecNode* input, Status error) override {
EVENT(span_, "ErrorReceived", {{"error", error.message()}});
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
EVENT(span_, "InputFinished", {{"batches.length", total_batches}});
DCHECK_EQ(input, inputs_[0]);

if (input_counter_.SetTotal(total_batches)) {
ErrorIfNotOk(Finish());
}
}

Status StartProducing() override {
START_SPAN(span_, std::string(kind_name()) + ":" + label(),
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
finished_ = Future<>::Make();
END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
// Scalar aggregates will only output a single batch
outputs_[0]->InputFinished(this, 1);
return Status::OK();
}

void PauseProducing(ExecNode* output) override {}
void PauseProducing(ExecNode* output) override { EVENT(span_, "PauseProducing"); }

void ResumeProducing(ExecNode* output) override {}
void ResumeProducing(ExecNode* output) override { EVENT(span_, "ResumeProducing"); }

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
}

void StopProducing() override {
EVENT(span_, "StopProducing");
if (input_counter_.Cancel()) {
finished_.MarkFinished();
}
@@ -235,10 +260,18 @@ class ScalarAggregateNode : public ExecNode {

private:
Status Finish() {
util::tracing::Span span;
START_SPAN(span, "Finish", {{"aggregate", ToStringExtra()}, {"node.label", label()}});
ExecBatch batch{{}, 1};
batch.values.resize(kernels_.size());

for (size_t i = 0; i < kernels_.size(); ++i) {
util::tracing::Span span;
START_SPAN(span, aggs_[i].function,
{{"function.name", aggs_[i].function},
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Finalize"}});
KernelContext ctx{plan()->exec_context()};
ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
kernels_[i], &ctx, std::move(states_[i])));
@@ -250,7 +283,6 @@ class ScalarAggregateNode : public ExecNode {
return Status::OK();
}

Future<> finished_ = Future<>::MakeFinished();
const std::vector<int> target_field_ids_;
const std::vector<internal::Aggregate> aggs_;
const std::vector<const ScalarAggregateKernel*> kernels_;
@@ -358,6 +390,11 @@ class GroupByNode : public ExecNode {
const char* kind_name() const override { return "GroupByNode"; }

Status Consume(ExecBatch batch) {
util::tracing::Span span;
START_SPAN(span, "Consume",
{{"group_by", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});
size_t thread_index = get_thread_index_();
if (thread_index >= local_states_.size()) {
return Status::IndexError("thread index ", thread_index, " is out of range [0, ",
@@ -379,6 +416,12 @@ class GroupByNode : public ExecNode {

// Execute aggregate kernels
for (size_t i = 0; i < agg_kernels_.size(); ++i) {
util::tracing::Span span;
START_SPAN(span, aggs_[i].function,
{{"function.name", aggs_[i].function},
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Consume"}});
KernelContext kernel_ctx{ctx_};
kernel_ctx.SetState(state->agg_states[i].get());

@@ -394,6 +437,8 @@ class GroupByNode : public ExecNode {
}

Status Merge() {
util::tracing::Span span;
START_SPAN(span, "Merge", {{"group_by", ToStringExtra()}, {"node.label", label()}});
ThreadLocalState* state0 = &local_states_[0];
for (size_t i = 1; i < local_states_.size(); ++i) {
ThreadLocalState* state = &local_states_[i];
@@ -406,6 +451,12 @@ class GroupByNode : public ExecNode {
state->grouper.reset();

for (size_t i = 0; i < agg_kernels_.size(); ++i) {
util::tracing::Span span;
START_SPAN(span, aggs_[i].function,
{{"function.name", aggs_[i].function},
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Merge"}});
KernelContext batch_ctx{ctx_};
DCHECK(state0->agg_states[i]);
batch_ctx.SetState(state0->agg_states[i].get());
@@ -420,6 +471,10 @@ class GroupByNode : public ExecNode {
}

Result<ExecBatch> Finalize() {
util::tracing::Span span;
START_SPAN(span, "Finalize",
{{"group_by", ToStringExtra()}, {"node.label", label()}});

ThreadLocalState* state = &local_states_[0];
// If we never got any batches, then state won't have been initialized
RETURN_NOT_OK(InitLocalStateIfNeeded(state));
@@ -429,6 +484,12 @@ class GroupByNode : public ExecNode {

// Aggregate fields come before key fields to match the behavior of GroupBy function
for (size_t i = 0; i < agg_kernels_.size(); ++i) {
util::tracing::Span span;
START_SPAN(span, aggs_[i].function,
{{"function.name", aggs_[i].function},
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Finalize"}});
KernelContext batch_ctx{ctx_};
batch_ctx.SetState(state->agg_states[i].get());
RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out_data.values[i]));
@@ -484,6 +545,13 @@ class GroupByNode : public ExecNode {
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
util::tracing::Span span;
START_SPAN_WITH_PARENT(span, span_, "InputReceived",
{{"group_by", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});

// bail if StopProducing was called
if (finished_.is_finished()) return;

@@ -497,12 +565,16 @@ class GroupByNode : public ExecNode {
}

void ErrorReceived(ExecNode* input, Status error) override {
EVENT(span_, "ErrorReceived", {{"error", error.message()}});

DCHECK_EQ(input, inputs_[0]);

outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
EVENT(span_, "InputFinished", {{"batches.length", total_batches}});

// bail if StopProducing was called
if (finished_.is_finished()) return;

@@ -514,17 +586,23 @@ class GroupByNode : public ExecNode {
}

Status StartProducing() override {
START_SPAN(span_, std::string(kind_name()) + ":" + label(),
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
finished_ = Future<>::Make();
END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);

local_states_.resize(ThreadIndexer::Capacity());
return Status::OK();
}

void PauseProducing(ExecNode* output) override {}
void PauseProducing(ExecNode* output) override { EVENT(span_, "PauseProducing"); }

void ResumeProducing(ExecNode* output) override {}
void ResumeProducing(ExecNode* output) override { EVENT(span_, "ResumeProducing"); }

void StopProducing(ExecNode* output) override {
EVENT(span_, "StopProducing");
DCHECK_EQ(output, outputs_[0]);

ARROW_UNUSED(input_counter_.Cancel());
@@ -603,7 +681,6 @@ class GroupByNode : public ExecNode {
}

ExecContext* ctx_;
Future<> finished_ = Future<>::MakeFinished();

const std::vector<int> key_field_ids_;
const std::vector<int> agg_src_field_ids_;
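Note: the `START_SPAN`, `START_SPAN_WITH_PARENT`, `EVENT`, and `END_SPAN_ON_FUTURE_COMPLETION` macros used above are defined in `arrow/util/tracing_internal.h` (not shown in this diff) and delegate to Arrow's OpenTelemetry integration. The sketch below is only a rough, dependency-free illustration of the pattern the nodes follow: a long-lived span per node started in `StartProducing`, a scoped span per `Consume`/`Merge`/`Finalize` call, a child span per aggregate kernel tagged with `function.name` and `function.options`, and lightweight events for callbacks such as `PauseProducing`. The `Span` class, its `Start`/`AddEvent` methods, and the sample values are assumptions made for illustration, not Arrow's API.

```cpp
// Hypothetical sketch only: Span, Start, AddEvent, Aggregate, and DoConsume
// are invented for illustration and are NOT Arrow's tracing API.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

class Span {
 public:
  Span() = default;
  ~Span() {
    if (!name_.empty()) std::cout << "end span: " << name_ << "\n";
  }
  void Start(const std::string& name,
             const std::map<std::string, std::string>& attrs) {
    name_ = name;
    std::cout << "start span: " << name_ << "\n";
    for (const auto& kv : attrs) {
      std::cout << "  attr " << kv.first << " = " << kv.second << "\n";
    }
  }
  void AddEvent(const std::string& event) const {
    std::cout << "event on " << name_ << ": " << event << "\n";
  }

 private:
  std::string name_;
};

struct Aggregate {
  std::string function;
  std::string options;
};

// Mirrors ScalarAggregateNode::DoConsume: an event on the node-lifetime span,
// one "Consume" span per batch, and a child span per aggregate kernel.
void DoConsume(const Span& node_span, const std::vector<Aggregate>& aggs,
               int64_t batch_length) {
  node_span.AddEvent("InputReceived");
  Span consume;
  consume.Start("Consume", {{"batch.length", std::to_string(batch_length)}});
  for (const auto& agg : aggs) {
    Span kernel;
    kernel.Start(agg.function, {{"function.name", agg.function},
                                {"function.options", agg.options}});
    // ... the kernel's consume call would run inside this scope ...
  }  // per-kernel child span ends here, when the scope exits
}

int main() {
  // Analogous to the span_ member started in StartProducing and ended when
  // the node's finished_ future completes.
  Span node_span;
  node_span.Start("ScalarAggregateNode:sum",
                  {{"node.kind", "ScalarAggregateNode"}, {"node.label", "sum"}});
  DoConsume(node_span, {{"sum", "ScalarAggregateOptions"}}, 1024);
  return 0;
}
```

The same shape repeats in `GroupByNode` above, where `Consume`, `Merge`, and `Finalize` each open their own span and wrap every kernel in a child span tagged with a `function.kind` attribute.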