ARROW-13313: [C++][Compute] Add scalar aggregate node #10705
Changes from all commits: 7fd1754, a47853d, 21a9595, b092fed, d77a0a6, b153862
```diff
@@ -18,12 +18,17 @@
 #include "arrow/compute/exec/exec_plan.h"
 
 #include <mutex>
 #include <thread>
 #include <unordered_map>
 #include <unordered_set>
 
 #include "arrow/array/util.h"
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/exec.h"
 #include "arrow/compute/exec/expression.h"
 #include "arrow/compute/registry.h"
 #include "arrow/datum.h"
 #include "arrow/record_batch.h"
 #include "arrow/result.h"
 #include "arrow/util/async_generator.h"
 #include "arrow/util/checked_cast.h"
```
```diff
@@ -33,6 +38,7 @@
 namespace arrow {
 
 using internal::checked_cast;
 using internal::checked_pointer_cast;
 
 namespace compute {
```
```diff
@@ -489,15 +495,23 @@ struct ProjectNode : ExecNode {
 };
 
 Result<ExecNode*> MakeProjectNode(ExecNode* input, std::string label,
-                                  std::vector<Expression> exprs) {
+                                  std::vector<Expression> exprs,
+                                  std::vector<std::string> names) {
   FieldVector fields(exprs.size());
 
+  if (names.size() == 0) {
+    names.resize(exprs.size());
+    for (size_t i = 0; i < exprs.size(); ++i) {
+      names[i] = exprs[i].ToString();
+    }
+  }
+
   int i = 0;
   for (auto& expr : exprs) {
     if (!expr.IsBound()) {
       ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*input->output_schema()));
     }
-    fields[i] = field(expr.ToString(), expr.type());
+    fields[i] = field(std::move(names[i]), expr.type());
     ++i;
   }
```
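For orientation, a hedged usage sketch of the new `names` argument (the `field_ref`/`call`/`literal` expression helpers are the existing ones from `arrow/compute/exec/expression.h`; `source` stands in for any already-constructed ExecNode, and the field names are illustrative):

```cpp
// Hypothetical: project a column and a derived value, overriding the default
// expr.ToString() field names with explicit ones. Passing an empty names
// vector falls back to expr.ToString() for each output field.
ARROW_ASSIGN_OR_RAISE(
    ExecNode* project,
    MakeProjectNode(source, "project",
                    /*exprs=*/{field_ref("x"), call("add", {field_ref("x"), literal(1)})},
                    /*names=*/{"x", "x_plus_one"}));
```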
```diff
@@ -552,15 +566,16 @@ struct SinkNode : ExecNode {
     ++num_received_;
     if (num_received_ == emit_stop_) {
       lock.unlock();
+      producer_.Push(std::move(batch));
       Finish();
-      lock.lock();
+      return;
     }
 
     if (emit_stop_ != -1) {
       DCHECK_LE(seq_num, emit_stop_);
     }
-    lock.unlock();
 
+    lock.unlock();
     producer_.Push(std::move(batch));
   }
```
```diff
@@ -574,8 +589,10 @@ struct SinkNode : ExecNode {
   void InputFinished(ExecNode* input, int seq_stop) override {
     std::unique_lock<std::mutex> lock(mutex_);
     emit_stop_ = seq_stop;
-    lock.unlock();
-    Finish();
+    if (num_received_ == emit_stop_) {
+      lock.unlock();
+      Finish();
+    }
   }
 
  private:
```
```diff
@@ -601,5 +618,205 @@ AsyncGenerator<util::optional<ExecBatch>> MakeSinkNode(ExecNode* input,
   return out;
 }
 
+std::shared_ptr<RecordBatchReader> MakeGeneratorReader(
+    std::shared_ptr<Schema> schema,
+    std::function<Future<util::optional<ExecBatch>>()> gen, MemoryPool* pool) {
+  struct Impl : RecordBatchReader {
+    std::shared_ptr<Schema> schema() const override { return schema_; }
+
+    Status ReadNext(std::shared_ptr<RecordBatch>* record_batch) override {
+      ARROW_ASSIGN_OR_RAISE(auto batch, iterator_.Next());
+      if (batch) {
+        ARROW_ASSIGN_OR_RAISE(*record_batch, batch->ToRecordBatch(schema_, pool_));
+      } else {
+        *record_batch = IterationEnd<std::shared_ptr<RecordBatch>>();
+      }
+      return Status::OK();
+    }
+
+    MemoryPool* pool_;
+    std::shared_ptr<Schema> schema_;
+    Iterator<util::optional<ExecBatch>> iterator_;
+  };
+
+  auto out = std::make_shared<Impl>();
+  out->pool_ = pool;
+  out->schema_ = std::move(schema);
+  out->iterator_ = MakeGeneratorIterator(std::move(gen));
+  return out;
+}
```
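As a consumption sketch (assuming a `schema` and sink generator `gen` from an assembled plan are already in hand), the adapter lets synchronous code drain a push-based plan through the familiar `RecordBatchReader` interface:

```cpp
// Hypothetical: pull batches synchronously from a plan's sink generator.
std::shared_ptr<RecordBatchReader> reader =
    MakeGeneratorReader(schema, std::move(gen), default_memory_pool());

std::shared_ptr<RecordBatch> batch;
do {
  ARROW_RETURN_NOT_OK(reader->ReadNext(&batch));
  if (batch) {
    // ... process batch ...
  }
} while (batch != nullptr);  // null signals IterationEnd
```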
```diff
+struct ScalarAggregateNode : ExecNode {
+  ScalarAggregateNode(ExecNode* input, std::string label,
+                      std::shared_ptr<Schema> output_schema,
+                      std::vector<const ScalarAggregateKernel*> kernels,
+                      std::vector<std::vector<std::unique_ptr<KernelState>>> states)
+      : ExecNode(input->plan(), std::move(label), {input}, {"target"},
+                 /*output_schema=*/std::move(output_schema),
+                 /*num_outputs=*/1),
+        kernels_(std::move(kernels)),
+        states_(std::move(states)) {}
+
+  const char* kind_name() override { return "ScalarAggregateNode"; }
+
+  Status DoConsume(const ExecBatch& batch, size_t thread_index) {
+    for (size_t i = 0; i < kernels_.size(); ++i) {
+      KernelContext batch_ctx{plan()->exec_context()};
+      batch_ctx.SetState(states_[i][thread_index].get());
+      ExecBatch single_column_batch{{batch.values[i]}, batch.length};
+      RETURN_NOT_OK(kernels_[i]->consume(&batch_ctx, single_column_batch));
+    }
+    return Status::OK();
+  }
+
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    std::unique_lock<std::mutex> lock(mutex_);
+    auto it = thread_indices_
+                  .emplace(std::this_thread::get_id(), thread_indices_.size())
+                  .first;
+    ++num_received_;
+    auto thread_index = it->second;
+
+    lock.unlock();
+
+    Status st = DoConsume(std::move(batch), thread_index);
+    if (!st.ok()) {
+      outputs_[0]->ErrorReceived(this, std::move(st));
+      return;
+    }
+
+    lock.lock();
```
**Member:** This lock could probably be removed. We might want to make a note to measure this with micro benchmarks someday. Only one thread should be finishing anyway, and the "what state blocks have we used" map could probably be a lock-free structure.

**Member (author):** `InputReceived(last batch)` might be called concurrently with `InputFinished`, so those two must synchronize to ensure only one does the finishing. It'd certainly be helpful to introduce less clumsy control flow in these classes.
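If this lock is revisited, one illustrative direction (not part of this PR; a sketch under the assumption that only the last-batch/finished handshake needs synchronizing) is an atomic counter plus compare-exchange, so that exactly one of the racing callers performs the finish:

```cpp
#include <atomic>
#include <utility>

// Sketch only: exactly-once finishing without the mutex. Both the thread
// delivering the last batch and the thread calling InputFinished race to
// flip `finishing`; compare_exchange guarantees a single winner.
struct FinishHandshake {
  std::atomic<int> num_received{0};
  std::atomic<int> num_total{-1};  // set once by InputFinished
  std::atomic<bool> finishing{false};

  template <typename Finish>
  void MaybeFinish(Finish&& finish) {
    if (num_received.load() != num_total.load()) return;
    bool expected = false;
    if (finishing.compare_exchange_strong(expected, true)) {
      std::forward<Finish>(finish)();  // runs on exactly one thread
    }
  }
};
```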
```diff
+    st = MaybeFinish(&lock);
+    if (!st.ok()) {
+      outputs_[0]->ErrorReceived(this, std::move(st));
+    }
+  }
+
+  void ErrorReceived(ExecNode* input, Status error) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->ErrorReceived(this, std::move(error));
+  }
+
+  void InputFinished(ExecNode* input, int seq) override {
+    DCHECK_EQ(input, inputs_[0]);
+    std::unique_lock<std::mutex> lock(mutex_);
+    num_total_ = seq;
+    Status st = MaybeFinish(&lock);
+
+    if (!st.ok()) {
+      outputs_[0]->ErrorReceived(this, std::move(st));
+    }
+  }
+
+  Status StartProducing() override {
+    finished_ = Future<>::Make();
+    // Scalar aggregates will only output a single batch
+    outputs_[0]->InputFinished(this, 1);
+    return Status::OK();
+  }
+
+  void PauseProducing(ExecNode* output) override {}
+
+  void ResumeProducing(ExecNode* output) override {}
+
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    StopProducing();
+  }
+
+  void StopProducing() override {
+    inputs_[0]->StopProducing(this);
+    finished_.MarkFinished();
+  }
+
+  Future<> finished() override { return finished_; }
+
+ private:
+  Status MaybeFinish(std::unique_lock<std::mutex>* lock) {
+    if (num_received_ != num_total_) return Status::OK();
+
+    if (finished_.is_finished()) return Status::OK();
+
+    ExecBatch batch{{}, 1};
+    batch.values.resize(kernels_.size());
+
+    for (size_t i = 0; i < kernels_.size(); ++i) {
```
**Member:** Maybe someday in the future we could merge each kernel on its own thread, but that can be for a future PR.

**Member (author):** Merging scalar aggregates is pretty trivial, so I'd guess we don't gain much with parallelization. Worth investigating in a follow-up, though.

**Member:** Ah, in my head "merge" meant something more like a merge sort. If it's just summing up a sum/mean/etc. counter across the various states, then I agree it's not necessary.
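For reference, the merge step being discussed is conceptually a left fold of the per-thread states into one. A sketch of those semantics (not the actual `MergeAll` implementation; the Status-returning `merge` signature is an assumption matching the `consume`/`finalize` style in this diff):

```cpp
// Sketch: fold states[1..n) into states[0] via the kernel's merge function.
Result<std::unique_ptr<KernelState>> MergeAllSketch(
    const ScalarAggregateKernel* kernel, KernelContext* ctx,
    std::vector<std::unique_ptr<KernelState>> states) {
  std::unique_ptr<KernelState> merged = std::move(states[0]);
  for (size_t i = 1; i < states.size(); ++i) {
    ctx->SetState(merged.get());
    // assumed signature: Status merge(KernelContext*, KernelState&&, KernelState*)
    RETURN_NOT_OK(kernel->merge(ctx, std::move(*states[i]), merged.get()));
  }
  return merged;
}
```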
```diff
+      KernelContext ctx{plan()->exec_context()};
+      ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
+                                             kernels_[i], &ctx, std::move(states_[i])));
+      RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
+    }
+    lock->unlock();
+
+    outputs_[0]->InputReceived(this, 0, batch);
+
+    finished_.MarkFinished();
+    return Status::OK();
+  }
+
+  Future<> finished_ = Future<>::MakeFinished();
+  std::vector<const ScalarAggregateKernel*> kernels_;
+  std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
+  std::unordered_map<std::thread::id, size_t> thread_indices_;
+  std::mutex mutex_;
+  int num_received_ = 0, num_total_;
+};
+
+Result<ExecNode*> MakeScalarAggregateNode(ExecNode* input, std::string label,
+                                          std::vector<internal::Aggregate> aggregates) {
+  if (input->output_schema()->num_fields() != static_cast<int>(aggregates.size())) {
+    return Status::Invalid("Provided ", aggregates.size(),
+                           " aggregates, expected one for each field of ",
+                           input->output_schema()->ToString());
+  }
+
+  auto exec_ctx = input->plan()->exec_context();
+
+  std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
+  std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
+  FieldVector fields(kernels.size());
+
+  for (size_t i = 0; i < kernels.size(); ++i) {
+    ARROW_ASSIGN_OR_RAISE(auto function,
+                          exec_ctx->func_registry()->GetFunction(aggregates[i].function));
+
+    if (function->kind() != Function::SCALAR_AGGREGATE) {
+      return Status::Invalid("Provided non ScalarAggregateFunction ",
+                             aggregates[i].function);
+    }
+
+    auto in_type = ValueDescr::Array(input->output_schema()->fields()[i]->type());
+
+    ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact({in_type}));
+    kernels[i] = static_cast<const ScalarAggregateKernel*>(kernel);
+
+    if (aggregates[i].options == nullptr) {
+      aggregates[i].options = function->default_options();
+    }
+
+    KernelContext kernel_ctx{exec_ctx};
+    states[i].resize(exec_ctx->executor() ? exec_ctx->executor()->GetCapacity() : 1);
+    RETURN_NOT_OK(Kernel::InitAll(
+        &kernel_ctx, KernelInitArgs{kernels[i], {in_type}, aggregates[i].options},
+        &states[i]));
+
+    // pick one to resolve the kernel signature
+    kernel_ctx.SetState(states[i][0].get());
+    ARROW_ASSIGN_OR_RAISE(
+        auto descr, kernels[i]->signature->out_type().Resolve(&kernel_ctx, {in_type}));
+
+    fields[i] = field(aggregates[i].function, std::move(descr.type));
+  }
+
+  return input->plan()->EmplaceNode<ScalarAggregateNode>(
+      input, std::move(label), schema(std::move(fields)), std::move(kernels),
+      std::move(states));
+}
+
 }  // namespace compute
 }  // namespace arrow
```
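Putting the pieces together, a hedged end-to-end sketch of wiring the new node into a plan (source construction is elided since it isn't part of this diff; the `{function, options}` descriptors follow `internal::Aggregate` as used above, and `MakeSinkNode` taking a label argument is assumed from its declaration):

```cpp
// Hypothetical wiring: source -> scalar aggregate -> sink.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> plan, ExecPlan::Make());

ExecNode* source = /* elided: any node on `plan` producing the input batches */ nullptr;

// One aggregate per input field, matching the check in MakeScalarAggregateNode.
ARROW_ASSIGN_OR_RAISE(
    ExecNode* agg,
    MakeScalarAggregateNode(source, "agg", {{"sum", nullptr}, {"count", nullptr}}));

// The sink generator yields exactly one batch of aggregated values, since
// StartProducing declares InputFinished(this, 1) downstream.
auto sink_gen = MakeSinkNode(agg, "sink");
ARROW_RETURN_NOT_OK(plan->StartProducing());
```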
**Member:** Question: to implement something like ARROW-12710 (string concat aggregate kernel) we'll need to know the order of inputs in the kernels (or will have to feed results into the kernel in order). How do we plan to handle that? Passing down `seq` and having each kernel reorder inputs itself, or perhaps with an upstream ExecNode that orders its inputs? This also applies to the group by node.

**Member (author):** `seq` is not an indication of order; it's only a tag in the range `[0, seq_stop)` (where `seq_stop` is set by `InputFinished`), so we could not use it to order results. As specified in ARROW-12710, the `KernelState` of the string concat agg kernel will need to include ordering criteria so that `merge(move(state1), &state0)` can be guaranteed equivalent to `merge(move(state0), &state1)`. Furthermore, `merge` cannot actually concatenate anything, because if we happened to first `merge(move(state0), &state3)` we'd have no way to insert `state1, state2` in the middle later; actual concatenation would have to wait for `finalize`. Those ordering criteria could be synthesized from (for example) fragment/batch index information, but the presence of `O(N)` state in a scalar agg kernel's State is suspect to me and I'm not sure it's a great match for ScalarAggregateKernel.

**Member:** Ah thanks, sorry for the misunderstanding (I need to stop thinking only about datasets). I suppose it only makes sense to talk about 'order' when directly downstream from a scan or explicit sort, then. And any aggregates that have O(N) state might properly belong as their own ExecNode.
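To make the `O(N)`-state concern concrete, here is a rough sketch (not part of this PR; simplified types) of the kind of order-aware concat state the reply describes, where `merge` only splices tagged chunks and actual concatenation waits for `finalize`:

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Illustrative only: state for an order-preserving string-concat aggregate.
// Each consumed batch appends (batch_index, partial) pairs, so merges can
// happen in any order, e.g. merge(state0, state3) before state1 arrives.
struct OrderedConcatState {
  std::vector<std::pair<int64_t, std::string>> chunks;

  void Merge(OrderedConcatState&& other) {
    // Splice, don't concatenate: ordering is resolved only at finalize time.
    chunks.insert(chunks.end(), std::make_move_iterator(other.chunks.begin()),
                  std::make_move_iterator(other.chunks.end()));
  }

  std::string Finalize() {
    std::sort(chunks.begin(), chunks.end());  // restore input order by tag
    std::string out;
    for (auto& chunk : chunks) out += chunk.second;
    return out;
  }
};
```

The `chunks` vector is exactly the `O(N)` state the author flags as a poor fit for ScalarAggregateKernel, which is why the follow-up suggests such aggregates might belong in their own ExecNode.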