Skip to content
242 changes: 242 additions & 0 deletions native/core/src/execution/merge_as_partial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode.
//!
//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate
//! state (not final values). DataFusion has no equivalent mode — `Partial` calls
//! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs
//! evaluated results.
//!
//! This wrapper bridges the gap: it operates under DataFusion's `Partial` mode (which
//! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge
//! semantics with state output.

use std::any::Any;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};

use arrow::array::{ArrayRef, BooleanArray};
use arrow::datatypes::{DataType, FieldRef};
use datafusion::common::Result;
use datafusion::logical_expr::function::AccumulatorArgs;
use datafusion::logical_expr::function::StateFieldsArgs;
use datafusion::logical_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF,
Signature, Volatility,
};
use datafusion::physical_expr::aggregate::AggregateFunctionExpr;
use datafusion::scalar::ScalarValue;

/// An AggregateUDF wrapper that gives merge semantics in Partial mode.
///
/// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch`
/// on each accumulator and outputs `state()`. This wrapper intercepts `update_batch`
/// and redirects it to `merge_batch` on the inner accumulator, effectively
/// implementing Spark's PartialMerge: merge inputs, output state.
///
/// We store the inner [`AggregateUDF`] (not the `AggregateFunctionExpr`) to avoid
/// keeping references to UnboundColumn expressions that would panic if evaluated.
#[derive(Debug)]
pub struct MergeAsPartialUDF {
    /// The inner aggregate UDF, cloned from the original expression.
    inner_udf: AggregateUDF,
    /// Pre-computed return type from the original expression; reported by
    /// `return_type` without re-deriving it from argument types.
    return_type: DataType,
    /// Pre-computed state fields from the original expression, reused verbatim
    /// so this stage's output schema matches subsequent Final/PartialMerge stages.
    cached_state_fields: Vec<FieldRef>,
    /// Cached permissive signature (`variadic_any`): inputs are state columns
    /// whose types vary per aggregate function.
    signature: Signature,
    /// Name for this wrapper: `merge_as_partial_<inner expression name>`.
    name: String,
}

impl PartialEq for MergeAsPartialUDF {
    /// Two wrappers are equal only when their full cached identity matches.
    ///
    /// Comparing the name alone is not enough: two wrappers can share a display
    /// name while wrapping aggregates with different return types or state
    /// schemas (e.g. the same function applied under different input schemas).
    /// Equal values still hash equally, because `Hash` uses only `name`.
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
            && self.return_type == other.return_type
            && self.cached_state_fields == other.cached_state_fields
    }
}

impl Eq for MergeAsPartialUDF {}

impl Hash for MergeAsPartialUDF {
    /// Hash only the wrapper name. This is consistent with `PartialEq`, which
    /// requires (at least) equal names for two values to compare equal.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        Hash::hash(&self.name, hasher);
    }
}

impl MergeAsPartialUDF {
    /// Builds a wrapper around the aggregate UDF of `inner_expr`.
    ///
    /// The expression's return type and state fields are captured eagerly so
    /// the wrapper never needs to re-evaluate the original argument
    /// expressions (which may contain UnboundColumns).
    ///
    /// # Errors
    /// Propagates any error from computing the inner expression's state fields.
    pub fn new(inner_expr: &AggregateFunctionExpr) -> Result<Self> {
        Ok(Self {
            inner_udf: inner_expr.fun().clone(),
            return_type: inner_expr.field().data_type().clone(),
            cached_state_fields: inner_expr.state_fields()?,
            // State-field types differ per aggregate function, so advertise a
            // permissive signature that accepts any argument types.
            signature: Signature::variadic_any(Volatility::Immutable),
            name: format!("merge_as_partial_{}", inner_expr.name()),
        })
    }
}

impl AggregateUDFImpl for MergeAsPartialUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        // In Partial mode, return_type isn't used for the output schema
        // (state_fields is). Return the inner function's return type for
        // consistency.
        Ok(self.return_type.clone())
    }

    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        // State fields must match the inner aggregate's state fields so that
        // the output of this PartialMerge stage is compatible with subsequent
        // Final or PartialMerge stages.
        Ok(self.cached_state_fields.clone())
    }

    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // `args` carries the resolved Column refs (not UnboundColumns), so it
        // is safe to hand straight to the inner UDF.
        let inner = self.inner_udf.accumulator(args)?;
        Ok(Box::new(MergeAsPartialAccumulator { inner }))
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        self.inner_udf.groups_accumulator_supported(args)
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let inner = self.inner_udf.create_groups_accumulator(args)?;
        Ok(Box::new(MergeAsPartialGroupsAccumulator { inner }))
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        // Reversal is irrelevant when merging state buffers.
        ReversedUDAF::NotSupported
    }

    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
        // Delegate to the inner UDF so aggregates with non-null defaults
        // (e.g. COUNT defaults to 0) keep their semantics, instead of always
        // producing a NULL via `ScalarValue::try_from(data_type)`.
        self.inner_udf.default_value(data_type)
    }

    fn is_descending(&self) -> Option<bool> {
        // Conservatively report no monotonicity: this wrapper emits state
        // buffers, whose ordering need not match the inner aggregate's
        // final-value ordering.
        None
    }
}

/// Accumulator wrapper that redirects `update_batch` to `merge_batch`.
struct MergeAsPartialAccumulator {
    /// The wrapped accumulator; all calls ultimately land here.
    inner: Box<dyn Accumulator>,
}

impl Debug for MergeAsPartialAccumulator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The inner accumulator is not `Debug`, so print only the type name —
        // identical output to `debug_struct(...).finish()` on a field-less struct.
        f.write_str("MergeAsPartialAccumulator")
    }
}

impl Accumulator for MergeAsPartialAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // The key trick: in PartialMerge mode the input columns are already
        // state buffers, so "updating" means merging them in.
        self.inner.merge_batch(values)
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.inner.merge_batch(states)
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        self.inner.state()
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        self.inner.evaluate()
    }

    fn size(&self) -> usize {
        self.inner.size()
    }
}

/// GroupsAccumulator wrapper that redirects `update_batch` to `merge_batch`.
struct MergeAsPartialGroupsAccumulator {
    /// The wrapped groups accumulator; all calls ultimately land here.
    inner: Box<dyn GroupsAccumulator>,
}

impl Debug for MergeAsPartialGroupsAccumulator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The inner accumulator is not `Debug`, so print only the type name —
        // identical output to `debug_struct(...).finish()` on a field-less struct.
        f.write_str("MergeAsPartialGroupsAccumulator")
    }
}

impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        // The key trick: in PartialMerge mode the input columns are already
        // state buffers, so "updating" means merging them in.
        self.inner
            .merge_batch(values, group_indices, filter, total_num_groups)
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.inner
            .merge_batch(values, group_indices, filter, total_num_groups)
    }

    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        self.inner.state(emit_to)
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        self.inner.evaluate(emit_to)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }
}
1 change: 1 addition & 0 deletions native/core/src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
pub mod columnar_to_row;
pub mod expressions;
pub mod jni_api;
pub(crate) mod merge_as_partial;
pub(crate) mod metrics;
pub mod operators;
pub(crate) mod planner;
Expand Down
81 changes: 76 additions & 5 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,19 +967,90 @@ impl PhysicalPlanner {
let group_by = PhysicalGroupBy::new_single(group_exprs?);
let schema = child.schema();

let mode = if agg.mode == 0 {
DFAggregateMode::Partial
} else {
DFAggregateMode::Final
let mode = match agg.mode {
0 => DFAggregateMode::Partial,
1 => DFAggregateMode::Final,
2 => DFAggregateMode::Partial, // PartialMerge: Partial + MergeAsPartial
other => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported aggregate mode: {other}"
)))
}
};

// Check if any expression uses PartialMerge mode (2). When present,
// those expressions are wrapped with MergeAsPartial to get merge
// semantics inside a Partial-mode AggregateExec.
let has_partial_merge = agg.mode == 2 || agg.expr_modes.contains(&2);

let agg_exprs: PhyAggResult = agg
.agg_exprs
.iter()
.map(|expr| self.create_agg_expr(expr, Arc::clone(&schema)))
.collect();

let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect();
let aggr_expr: Vec<Arc<AggregateFunctionExpr>> = if has_partial_merge {
// Wrap PartialMerge expressions with MergeAsPartial.
// State fields in the child's output start at initial_input_buffer_offset.
let mut state_offset = agg.initial_input_buffer_offset as usize;
let per_expr_modes: Vec<i32> = if !agg.expr_modes.is_empty() {
agg.expr_modes.clone()
} else {
vec![agg.mode; agg.agg_exprs.len()]
};

agg_exprs?
.into_iter()
.enumerate()
.map(|(idx, expr)| {
if per_expr_modes[idx] == 2 {
// PartialMerge: wrap with MergeAsPartial
let state_fields = expr
.state_fields()
.map_err(|e| ExecutionError::GeneralError(e.to_string()))?;
let num_state_fields = state_fields.len();

let state_cols: Vec<Arc<dyn PhysicalExpr>> = (0..num_state_fields)
.map(|i| {
let col_idx = state_offset + i;
let field = schema.field(col_idx);
Arc::new(Column::new(field.name(), col_idx))
as Arc<dyn PhysicalExpr>
})
.collect();
state_offset += num_state_fields;

let merge_udf =
crate::execution::merge_as_partial::MergeAsPartialUDF::new(
&expr,
)
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?;
let merge_udf_arc = Arc::new(
datafusion::logical_expr::AggregateUDF::new_from_impl(
merge_udf,
),
);

let merge_expr =
AggregateExprBuilder::new(merge_udf_arc, state_cols)
.schema(Arc::clone(&schema))
.alias(format!("col_{idx}"))
.with_ignore_nulls(expr.ignore_nulls())
.with_distinct(expr.is_distinct())
.build()
.map_err(|e| {
ExecutionError::DataFusionError(e.to_string())
})?;

Ok(Arc::new(merge_expr))
} else {
Ok(Arc::new(expr))
}
})
.collect::<Result<Vec<_>, ExecutionError>>()?
} else {
agg_exprs?.into_iter().map(Arc::new).collect()
};

// Build per-aggregate filter expressions from the FILTER (WHERE ...) clause.
// Filters are only present in Partial mode; Final/PartialMerge always get None.
Expand Down
8 changes: 8 additions & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ message HashAggregate {
repeated spark.spark_expression.AggExpr agg_exprs = 2;
repeated spark.spark_expression.Expr result_exprs = 3;
AggregateMode mode = 5;
// Per-expression modes for mixed-mode aggregates (e.g., PartialMerge + Partial).
// When set, each entry corresponds to agg_exprs at the same index.
// When empty, all expressions use the `mode` field.
repeated AggregateMode expr_modes = 6;
// Offset in the child's output where aggregate buffer attributes start.
// Used by PartialMerge to locate state fields in the input.
int32 initial_input_buffer_offset = 7;
}

message Limit {
Expand Down Expand Up @@ -319,6 +326,7 @@ message ParquetWriter {
// Aggregation execution modes, mirroring Spark's AggregateMode.
enum AggregateMode {
  // Update from raw input rows; output intermediate state buffers.
  Partial = 0;
  // Merge intermediate state buffers; output final evaluated values.
  Final = 1;
  // Merge intermediate state buffers; output intermediate state buffers
  // (Spark's PartialMerge mode).
  PartialMerge = 2;
}

message Expand {
Expand Down
Loading
Loading