Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e62c77c
Progress commit
nealrichardson Jul 14, 2021
051074e
Apply Ben's patch and sketch R side
nealrichardson Jul 14, 2021
41c0826
More R; try to get C++ to compile
nealrichardson Jul 15, 2021
1947e15
const
nealrichardson Jul 15, 2021
1bc0789
improve keepalive pattern
bkietz Jul 15, 2021
b5b41a3
Compiles but segfaults
nealrichardson Jul 15, 2021
f34c932
revert keepalives
bkietz Jul 15, 2021
683dbcc
Actually run the tests
nealrichardson Jul 15, 2021
a1f676d
Restore docs
nealrichardson Jul 15, 2021
100a178
Restore try()
nealrichardson Jul 15, 2021
d3190a2
Use FieldsInExpression to project in Scan
nealrichardson Jul 15, 2021
1107cd2
repair merge error
bkietz Jul 26, 2021
2576f59
Basic exercise of GroupByNode
nealrichardson Jul 26, 2021
1b423a0
fix ExecBatch slicing
bkietz Jul 27, 2021
1816f2c
Adapt result to meet dplyr expectation
nealrichardson Jul 27, 2021
776e1f5
Remove some tests for features not implemented for datasets since tha…
nealrichardson Jul 29, 2021
58f4930
Refactor agg function definition and registry and add any/all
nealrichardson Jul 29, 2021
aeb0bf8
Add jira references
nealrichardson Jul 30, 2021
a7f5cde
Use filter node to actually filter
nealrichardson Jul 30, 2021
eab89e8
Format and re-doc
nealrichardson Aug 3, 2021
da43f5c
Remove feature flag
nealrichardson Aug 4, 2021
56df2d3
handle .groups argument
nealrichardson Aug 4, 2021
f5d5d30
Prevent na.rm = FALSE aggregation because it's wrong
nealrichardson Aug 4, 2021
b906253
Merge branch 'scalar-aggregate-node' of github.com:nealrichardson/arr…
nealrichardson Aug 4, 2021
6922815
Suppress warning and style files
nealrichardson Aug 4, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
if (value.is_scalar()) continue;
value = value.array()->Slice(offset, length);
}
out.length = length;
out.length = std::min(length, this->length - offset);
return out;
}

Expand Down
35 changes: 26 additions & 9 deletions cpp/src/arrow/compute/exec/exec_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -719,11 +719,13 @@ struct ScalarAggregateNode : ExecNode {
ScalarAggregateNode(ExecNode* input, std::string label,
std::shared_ptr<Schema> output_schema,
std::vector<const ScalarAggregateKernel*> kernels,
std::vector<int> argument_indices,
std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(input->plan(), std::move(label), {input}, {"target"},
/*output_schema=*/std::move(output_schema),
/*num_outputs=*/1),
kernels_(std::move(kernels)),
argument_indices_(std::move(argument_indices)),
states_(std::move(states)) {}

const char* kind_name() override { return "ScalarAggregateNode"; }
Expand All @@ -733,7 +735,7 @@ struct ScalarAggregateNode : ExecNode {
KernelContext batch_ctx{plan()->exec_context()};
batch_ctx.SetState(states_[i][thread_index].get());

ExecBatch single_column_batch{{batch.values[i]}, batch.length};
ExecBatch single_column_batch{{batch[argument_indices_[i]]}, batch.length};
RETURN_NOT_OK(kernels_[i]->consume(&batch_ctx, single_column_batch));
}
return Status::OK();
Expand Down Expand Up @@ -807,7 +809,8 @@ struct ScalarAggregateNode : ExecNode {
}

Future<> finished_ = Future<>::MakeFinished();
std::vector<const ScalarAggregateKernel*> kernels_;
const std::vector<const ScalarAggregateKernel*> kernels_;
const std::vector<int> argument_indices_;

std::vector<std::vector<std::unique_ptr<KernelState>>> states_;

Expand All @@ -816,20 +819,34 @@ struct ScalarAggregateNode : ExecNode {
};

Result<ExecNode*> MakeScalarAggregateNode(ExecNode* input, std::string label,
std::vector<internal::Aggregate> aggregates) {
if (input->output_schema()->num_fields() != static_cast<int>(aggregates.size())) {
return Status::Invalid("Provided ", aggregates.size(),
" aggregates, expected one for each field of ",
input->output_schema()->ToString());
std::vector<internal::Aggregate> aggregates,
std::vector<FieldRef> arguments,
std::vector<std::string> out_field_names) {
if (aggregates.size() != arguments.size()) {
return Status::Invalid("Provided ", aggregates.size(), " aggregates but ",
arguments.size(), " arguments.");
}

if (aggregates.size() != out_field_names.size()) {
return Status::Invalid("Provided ", aggregates.size(), " aggregates but ",
out_field_names.size(), " field names for the output.");
}

auto exec_ctx = input->plan()->exec_context();

std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
FieldVector fields(kernels.size());
std::vector<int> argument_indices(kernels.size());

for (size_t i = 0; i < kernels.size(); ++i) {
if (!arguments[i].IsName()) {
return Status::NotImplemented("Non name field refs");
}
ARROW_ASSIGN_OR_RAISE(auto match,
arguments[i].FindOneOrNone(*input->output_schema()));
argument_indices[i] = match[0];

ARROW_ASSIGN_OR_RAISE(auto function,
exec_ctx->func_registry()->GetFunction(aggregates[i].function));

Expand Down Expand Up @@ -862,12 +879,12 @@ Result<ExecNode*> MakeScalarAggregateNode(ExecNode* input, std::string label,
ARROW_ASSIGN_OR_RAISE(
auto descr, kernels[i]->signature->out_type().Resolve(&kernel_ctx, {in_type}));

fields[i] = field(aggregates[i].function, std::move(descr.type));
fields[i] = field(std::move(out_field_names[i]), std::move(descr.type));
}

return input->plan()->EmplaceNode<ScalarAggregateNode>(
input, std::move(label), schema(std::move(fields)), std::move(kernels),
std::move(states));
std::move(argument_indices), std::move(states));
}

namespace internal {
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/arrow/compute/exec/exec_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,9 @@ Result<ExecNode*> MakeProjectNode(ExecNode* input, std::string label,

ARROW_EXPORT
Result<ExecNode*> MakeScalarAggregateNode(ExecNode* input, std::string label,
std::vector<internal::Aggregate> aggregates);
std::vector<internal::Aggregate> aggregates,
std::vector<FieldRef> arguments,
std::vector<std::string> out_field_names);

/// \brief Make a node which groups input rows based on key fields and computes
/// aggregates for each group
Expand Down
11 changes: 7 additions & 4 deletions cpp/src/arrow/compute/exec/plan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,11 @@ TEST(ExecPlanExecution, SourceScalarAggSink) {
MakeTestSourceNode(plan.get(), "source", basic_data,
/*parallel=*/false, /*slow=*/false));

ASSERT_OK_AND_ASSIGN(auto scalar_agg,
MakeScalarAggregateNode(source, "scalar_agg",
{{"sum", nullptr}, {"any", nullptr}}));
ASSERT_OK_AND_ASSIGN(
auto scalar_agg,
MakeScalarAggregateNode(source, "scalar_agg", {{"sum", nullptr}, {"any", nullptr}},
/*targets=*/{"i32", "bool"},
/*out_field_names=*/{"sum(i32)", "any(bool)"}));

auto sink_gen = MakeSinkNode(scalar_agg, "sink");

Expand Down Expand Up @@ -565,7 +567,8 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
ASSERT_OK_AND_ASSIGN(
auto scalar_agg,
MakeScalarAggregateNode(source, "scalar_agg",
{{"count", nullptr}, {"sum", nullptr}, {"mean", nullptr}}));
{{"count", nullptr}, {"sum", nullptr}, {"mean", nullptr}},
{"a", "b", "c"}, {"sum a", "sum b", "sum c"}));

auto sink_gen = MakeSinkNode(scalar_agg, "sink");

Expand Down
9 changes: 5 additions & 4 deletions cpp/src/arrow/dataset/scanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -816,14 +816,15 @@ Result<int64_t> AsyncScanner::CountRows() {
ARROW_ASSIGN_OR_RAISE(auto scan,
MakeScanNode(plan.get(), std::move(fragment_gen), options));

ARROW_ASSIGN_OR_RAISE(
auto get_selection,
compute::MakeProjectNode(scan, "get_selection", {options->filter}));
ARROW_ASSIGN_OR_RAISE(auto get_selection,
compute::MakeProjectNode(scan, "get_selection", {options->filter},
{"selection_mask"}));

ARROW_ASSIGN_OR_RAISE(
auto sum_selection,
compute::MakeScalarAggregateNode(get_selection, "sum_selection",
{compute::internal::Aggregate{"sum", nullptr}}));
{compute::internal::Aggregate{"sum", nullptr}},
{"selection_mask"}, {"sum"}));

AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen =
compute::MakeSinkNode(sum_selection, "sink");
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/arrow/dataset/scanner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1471,14 +1471,16 @@ TEST(ScanNode, MinimalScalarAggEndToEnd) {
ASSERT_OK_AND_ASSIGN(
compute::ExecNode * sum,
compute::MakeScalarAggregateNode(project, "scalar_agg",
{compute::internal::Aggregate{"sum", nullptr}}));
{compute::internal::Aggregate{"sum", nullptr}},
{a_times_2.ToString()}, {"a*2 sum"}));

// finally, pipe the project node into a sink node
auto sink_gen = compute::MakeSinkNode(sum, "sink");

// translate sink_gen (async) to sink_reader (sync)
std::shared_ptr<RecordBatchReader> sink_reader = compute::MakeGeneratorReader(
schema({field("sum", int64())}), std::move(sink_gen), exec_context.memory_pool());
std::shared_ptr<RecordBatchReader> sink_reader =
compute::MakeGeneratorReader(schema({field("a*2 sum", int64())}),
std::move(sink_gen), exec_context.memory_pool());

// start the ExecPlan
ASSERT_OK(plan->StartProducing());
Expand All @@ -1489,9 +1491,9 @@ TEST(ScanNode, MinimalScalarAggEndToEnd) {
// wait 1s for completion
ASSERT_TRUE(plan->finished().Wait(/*seconds=*/1)) << "ExecPlan didn't finish within 1s";

auto expected = TableFromJSON(schema({field("sum", int64())}), {
R"([
{"sum": 4}
auto expected = TableFromJSON(schema({field("a*2 sum", int64())}), {
R"([
{"a*2 sum": 4}
])"});
AssertTablesEqual(*expected, *collected, /*same_chunk_layout=*/false);
}
Expand Down
1 change: 1 addition & 0 deletions r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Collate:
'metadata.R'
'parquet.R'
'python.R'
'query-engine.R'
'record-batch-reader.R'
'record-batch-writer.R'
'reexports-bit64.R'
Expand Down
Loading