diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index 53c2a9c5da7..6c3c71161ef 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -419,20 +419,7 @@ impl Max { impl AggregateExpr for Max { fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - Ok(DataType::Int64) - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - Ok(DataType::UInt64) - } - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "MAX does not support {:?}", - other - ))), - } + self.expr.data_type(input_schema) } fn evaluate_input(&self, batch: &RecordBatch) -> Result { @@ -449,13 +436,13 @@ impl AggregateExpr for Max { } macro_rules! max_accumulate { - ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{ + ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident) => {{ $SELF.max = match $SELF.max { Some(ScalarValue::$SCALAR_VARIANT(n)) => { - if n > ($VALUE as $TY) { + if n > ($VALUE) { Some(ScalarValue::$SCALAR_VARIANT(n)) } else { - Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)) + Some(ScalarValue::$SCALAR_VARIANT($VALUE)) } } Some(_) => { @@ -463,7 +450,7 @@ macro_rules! max_accumulate { "Unexpected ScalarValue variant".to_string(), )) } - None => Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)), + None => Some(ScalarValue::$SCALAR_VARIANT($VALUE)), }; }}; } @@ -477,34 +464,34 @@ impl Accumulator for MaxAccumulator { if let Some(value) = value { match value { ScalarValue::Int8(value) => { - max_accumulate!(self, value, Int8Array, Int64, i64); + max_accumulate!(self, value, Int8Array, Int8); } ScalarValue::Int16(value) => { - max_accumulate!(self, value, Int16Array, Int64, i64) + max_accumulate!(self, value, Int16Array, Int16) } ScalarValue::Int32(value) => { - max_accumulate!(self, value, Int32Array, Int64, i64) + max_accumulate!(self, value, Int32Array, Int32) } ScalarValue::Int64(value) => { - max_accumulate!(self, value, Int64Array, Int64, i64) + max_accumulate!(self, value, Int64Array, Int64) } ScalarValue::UInt8(value) => { - max_accumulate!(self, value, UInt8Array, UInt64, u64) + max_accumulate!(self, value, UInt8Array, UInt8) } ScalarValue::UInt16(value) => { - max_accumulate!(self, value, UInt16Array, UInt64, u64) + max_accumulate!(self, value, UInt16Array, UInt16) } ScalarValue::UInt32(value) => { - max_accumulate!(self, value, UInt32Array, UInt64, u64) + max_accumulate!(self, value, UInt32Array, UInt32) } ScalarValue::UInt64(value) => { - max_accumulate!(self, value, UInt64Array, UInt64, u64) + max_accumulate!(self, value, UInt64Array, UInt64) } ScalarValue::Float32(value) => { - max_accumulate!(self, value, Float32Array, Float32, f32) + max_accumulate!(self, value, Float32Array, Float32) } ScalarValue::Float64(value) => { - max_accumulate!(self, value, Float64Array, Float64, f64) + max_accumulate!(self, value, Float64Array, Float64) } other => { return Err(ExecutionError::General(format!( @@ -616,20 +603,7 @@ impl Min { impl AggregateExpr for Min { fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - Ok(DataType::Int64) - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - Ok(DataType::UInt64) - } - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "MIN does not support {:?}", - other - ))), - } + self.expr.data_type(input_schema) } fn evaluate_input(&self, batch: &RecordBatch) -> Result { @@ -646,13 +620,13 @@ impl AggregateExpr for Min { } macro_rules! min_accumulate { - ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{ + ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident) => {{ $SELF.min = match $SELF.min { Some(ScalarValue::$SCALAR_VARIANT(n)) => { - if n < ($VALUE as $TY) { + if n < ($VALUE) { Some(ScalarValue::$SCALAR_VARIANT(n)) } else { - Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)) + Some(ScalarValue::$SCALAR_VARIANT($VALUE)) } } Some(_) => { @@ -660,7 +634,7 @@ macro_rules! min_accumulate { "Unexpected ScalarValue variant".to_string(), )) } - None => Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)), + None => Some(ScalarValue::$SCALAR_VARIANT($VALUE)), }; }}; } @@ -674,34 +648,34 @@ impl Accumulator for MinAccumulator { if let Some(value) = value { match value { ScalarValue::Int8(value) => { - min_accumulate!(self, value, Int8Array, Int64, i64); + min_accumulate!(self, value, Int8Array, Int8); } ScalarValue::Int16(value) => { - min_accumulate!(self, value, Int16Array, Int64, i64) + min_accumulate!(self, value, Int16Array, Int16) } ScalarValue::Int32(value) => { - min_accumulate!(self, value, Int32Array, Int64, i64) + min_accumulate!(self, value, Int32Array, Int32) } ScalarValue::Int64(value) => { - min_accumulate!(self, value, Int64Array, Int64, i64) + min_accumulate!(self, value, Int64Array, Int64) } ScalarValue::UInt8(value) => { - min_accumulate!(self, value, UInt8Array, UInt64, u64) + min_accumulate!(self, value, UInt8Array, UInt8) } ScalarValue::UInt16(value) => { - min_accumulate!(self, value, UInt16Array, UInt64, u64) + min_accumulate!(self, value, UInt16Array, UInt16) } ScalarValue::UInt32(value) => { - min_accumulate!(self, value, UInt32Array, UInt64, u64) + min_accumulate!(self, value, UInt32Array, UInt32) } ScalarValue::UInt64(value) => { - min_accumulate!(self, value, UInt64Array, UInt64, u64) + min_accumulate!(self, value, UInt64Array, UInt64) } ScalarValue::Float32(value) => { - min_accumulate!(self, value, Float32Array, Float32, f32) + min_accumulate!(self, value, Float32Array, Float32) } ScalarValue::Float64(value) => { - min_accumulate!(self, value, Float64Array, Float64, f64) + min_accumulate!(self, value, Float64Array, Float64) } other => { return Err(ExecutionError::General(format!( @@ -1481,7 +1455,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let max = max(col("a")); - assert_eq!(DataType::Int64, max.data_type(&schema)?); + assert_eq!(DataType::Int32, max.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ @@ -1490,7 +1464,7 @@ mod tests { ]); let combiner = max.create_reducer("Max(a)"); - assert_eq!(DataType::Int64, combiner.data_type(&schema)?); + assert_eq!(DataType::Int32, combiner.data_type(&schema)?); Ok(()) } @@ -1500,7 +1474,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let min = min(col("a")); - assert_eq!(DataType::Int64, min.data_type(&schema)?); + assert_eq!(DataType::Int32, min.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ @@ -1508,7 +1482,7 @@ mod tests { Field::new("MIN(a)", min.data_type(&schema)?, false), ]); let combiner = min.create_reducer("MIN(a)"); - assert_eq!(DataType::Int64, combiner.data_type(&schema)?); + assert_eq!(DataType::Int32, combiner.data_type(&schema)?); Ok(()) } @@ -1562,7 +1536,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5))); + assert_eq!(do_max(&batch)?, Some(ScalarValue::Int32(5))); Ok(()) } @@ -1574,7 +1548,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1))); + assert_eq!(do_min(&batch)?, Some(ScalarValue::Int32(1))); Ok(()) } @@ -1610,7 +1584,7 @@ mod tests { let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5))); + assert_eq!(do_max(&batch)?, Some(ScalarValue::Int32(5))); Ok(()) } @@ -1622,7 +1596,7 @@ mod tests { let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1))); + assert_eq!(do_min(&batch)?, Some(ScalarValue::Int32(1))); Ok(()) } @@ -1706,7 +1680,7 @@ mod tests { let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - assert_eq!(do_max(&batch)?, Some(ScalarValue::UInt64(5_u64))); + assert_eq!(do_max(&batch)?, Some(ScalarValue::UInt32(5_u32))); Ok(()) } @@ -1718,7 +1692,7 @@ mod tests { let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - assert_eq!(do_min(&batch)?, Some(ScalarValue::UInt64(1_u64))); + assert_eq!(do_min(&batch)?, Some(ScalarValue::UInt32(1_u32))); Ok(()) } diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index 99f1161f4bb..0b20c8df571 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -27,8 +27,9 @@ use crate::execution::physical_plan::{ }; use arrow::array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayBuilder, ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, }; use arrow::array::{ Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, @@ -166,64 +167,13 @@ impl Partition for HashAggregatePartition { } } -/// Create array from `key` attribute in map entry (representing a grouping scalar value) -macro_rules! group_array_from_map_entries { - ($BUILDER:ident, $TY:ident, $MAP:expr, $COL_INDEX:expr) => {{ - let mut builder = $BUILDER::new($MAP.len()); - let mut err = false; - for k in $MAP.keys() { - match k[$COL_INDEX] { - GroupByScalar::$TY(n) => builder.append_value(n).unwrap(), - _ => err = true, - } - } - if err { - Err(ExecutionError::ExecutionError( - "unexpected type when creating grouping array from aggregate map" - .to_string(), - )) - } else { - Ok(Arc::new(builder.finish()) as ArrayRef) - } - }}; -} - -/// Create array from `value` attribute in map entry (representing an aggregate scalar -/// value) -macro_rules! aggr_array_from_map_entries { - ($BUILDER:ident, $TY:ident, $TY2:ty, $MAP:expr, $COL_INDEX:expr) => {{ - let mut builder = $BUILDER::new($MAP.len()); - let mut err = false; - for v in $MAP.values() { - match v[$COL_INDEX] - .as_ref() - .borrow() - .get_value() - .map_err(ExecutionError::into_arrow_external_error)? - { - Some(ScalarValue::$TY(n)) => builder.append_value(n as $TY2).unwrap(), - None => builder.append_null().unwrap(), - _ => err = true, - } - } - if err { - Err(ExecutionError::ExecutionError( - "unexpected type when creating aggregate array from aggregate map" - .to_string(), - )) - } else { - Ok(Arc::new(builder.finish()) as ArrayRef) - } - }}; -} - /// Create array from single accumulator value -macro_rules! aggr_array_from_accumulator { - ($BUILDER:ident, $TY:ident, $TY2:ty, $VALUE:expr) => {{ +macro_rules! accum_val { + ($BUILDER:ident, $SCALAR_TY:ident, $VALUE:expr) => {{ let mut builder = $BUILDER::new(1); match $VALUE { - Some(ScalarValue::$TY(n)) => { - builder.append_value(n as $TY2)?; + Some(ScalarValue::$SCALAR_TY(n)) => { + builder.append_value(n)?; Ok(Arc::new(builder.finish()) as ArrayRef) } None => { @@ -231,19 +181,13 @@ macro_rules! aggr_array_from_accumulator { Ok(Arc::new(builder.finish()) as ArrayRef) } _ => Err(ExecutionError::ExecutionError( - "unexpected type when creating aggregate array from aggregate map" + "unexpected type when creating aggregate array from no-group aggregate" .to_string(), )), } }}; } -#[derive(Debug)] -struct MapEntry { - k: Vec, - v: Vec>, -} - struct GroupedHashAggregateIterator { schema: SchemaRef, group_expr: Vec>, @@ -272,7 +216,7 @@ impl GroupedHashAggregateIterator { type AccumulatorSet = Vec>>; -macro_rules! update_accumulators { +macro_rules! update_accum { ($ARRAY:ident, $ARRAY_TY:ident, $SCALAR_TY:expr, $COL:expr, $ACCUM:expr) => {{ let primitive_array = $ARRAY.as_any().downcast_ref::<$ARRAY_TY>().unwrap(); @@ -336,7 +280,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { } // iterate over each row in the batch and create the accumulators for each grouping key - let mut accumulators: Vec> = + let mut accums: Vec> = Vec::with_capacity(batch.num_rows()); for row in 0..batch.num_rows() { @@ -345,7 +289,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { .map_err(ExecutionError::into_arrow_external_error)?; if let Some(accumulator_set) = map.get(&key) { - accumulators.push(accumulator_set.clone()); + accums.push(accumulator_set.clone()); } else { let accumulator_set: AccumulatorSet = self .aggr_expr @@ -356,7 +300,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { let accumulator_set = Rc::new(accumulator_set); map.insert(key.clone(), accumulator_set.clone()); - accumulators.push(accumulator_set); + accums.push(accumulator_set); } } @@ -366,75 +310,55 @@ impl RecordBatchReader for GroupedHashAggregateIterator { let array = &aggr_input_values[col]; match array.data_type() { - DataType::Int8 => update_accumulators!( - array, - Int8Array, - ScalarValue::Int8, - col, - accumulators - ), - DataType::Int16 => update_accumulators!( - array, - Int16Array, - ScalarValue::Int16, - col, - accumulators - ), - DataType::Int32 => update_accumulators!( - array, - Int32Array, - ScalarValue::Int32, - col, - accumulators - ), - DataType::Int64 => update_accumulators!( - array, - Int64Array, - ScalarValue::Int64, - col, - accumulators - ), - DataType::UInt8 => update_accumulators!( - array, - UInt8Array, - ScalarValue::UInt8, - col, - accumulators - ), - DataType::UInt16 => update_accumulators!( + DataType::Int8 => { + update_accum!(array, Int8Array, ScalarValue::Int8, col, accums) + } + DataType::Int16 => { + update_accum!(array, Int16Array, ScalarValue::Int16, col, accums) + } + DataType::Int32 => { + update_accum!(array, Int32Array, ScalarValue::Int32, col, accums) + } + DataType::Int64 => { + update_accum!(array, Int64Array, ScalarValue::Int64, col, accums) + } + DataType::UInt8 => { + update_accum!(array, UInt8Array, ScalarValue::UInt8, col, accums) + } + DataType::UInt16 => update_accum!( array, UInt16Array, ScalarValue::UInt16, col, - accumulators + accums ), - DataType::UInt32 => update_accumulators!( + DataType::UInt32 => update_accum!( array, UInt32Array, ScalarValue::UInt32, col, - accumulators + accums ), - DataType::UInt64 => update_accumulators!( + DataType::UInt64 => update_accum!( array, UInt64Array, ScalarValue::UInt64, col, - accumulators + accums ), - DataType::Float32 => update_accumulators!( + DataType::Float32 => update_accum!( array, Float32Array, ScalarValue::Float32, col, - accumulators + accums ), - DataType::Float64 => update_accumulators!( + DataType::Float64 => update_accum!( array, Float64Array, ScalarValue::Float64, col, - accumulators + accums ), other => { return Err(ExecutionError::ExecutionError(format!( @@ -446,108 +370,14 @@ impl RecordBatchReader for GroupedHashAggregateIterator { } } - let input_schema = input.schema(); - - // build the result arrays - let mut result_arrays: Vec = - Vec::with_capacity(self.group_expr.len() + self.aggr_expr.len()); - - // grouping values - for i in 0..self.group_expr.len() { - let array: Result = match self.group_expr[i] - .data_type(&input_schema) - .map_err(ExecutionError::into_arrow_external_error)? - { - DataType::UInt8 => { - group_array_from_map_entries!(UInt8Builder, UInt8, map, i) - } - DataType::UInt16 => { - group_array_from_map_entries!(UInt16Builder, UInt16, map, i) - } - DataType::UInt32 => { - group_array_from_map_entries!(UInt32Builder, UInt32, map, i) - } - DataType::UInt64 => { - group_array_from_map_entries!(UInt64Builder, UInt64, map, i) - } - DataType::Int8 => { - group_array_from_map_entries!(Int8Builder, Int8, map, i) - } - DataType::Int16 => { - group_array_from_map_entries!(Int16Builder, Int16, map, i) - } - DataType::Int32 => { - group_array_from_map_entries!(Int32Builder, Int32, map, i) - } - DataType::Int64 => { - group_array_from_map_entries!(Int64Builder, Int64, map, i) - } - DataType::Utf8 => { - let mut builder = StringBuilder::new(1); - for k in map.keys() { - match &k[i] { - GroupByScalar::Utf8(s) => builder.append_value(&s).unwrap(), - _ => { - return Err(ExecutionError::ExecutionError( - "Unexpected value for Utf8 group column".to_string(), - ) - .into_arrow_external_error()) - } - } - } - Ok(Arc::new(builder.finish()) as ArrayRef) - } - _ => Err(ExecutionError::ExecutionError( - "Unsupported group by expr".to_string(), - )), - }; - result_arrays.push(array.map_err(ExecutionError::into_arrow_external_error)?); - } + let batch = create_batch_from_map( + &map, + self.group_expr.len(), + self.aggr_expr.len(), + &self.schema, + ) + .map_err(ExecutionError::into_arrow_external_error)?; - // aggregate values - for i in 0..self.aggr_expr.len() { - let aggr_data_type = self.aggr_expr[i] - .data_type(&input_schema) - .map_err(ExecutionError::into_arrow_external_error)?; - let array = match aggr_data_type { - DataType::UInt8 => { - aggr_array_from_map_entries!(UInt64Builder, UInt8, u64, map, i) - } - DataType::UInt16 => { - aggr_array_from_map_entries!(UInt64Builder, UInt16, u64, map, i) - } - DataType::UInt32 => { - aggr_array_from_map_entries!(UInt64Builder, UInt32, u64, map, i) - } - DataType::UInt64 => { - aggr_array_from_map_entries!(UInt64Builder, UInt64, u64, map, i) - } - DataType::Int8 => { - aggr_array_from_map_entries!(Int64Builder, Int8, i64, map, i) - } - DataType::Int16 => { - aggr_array_from_map_entries!(Int64Builder, Int16, i64, map, i) - } - DataType::Int32 => { - aggr_array_from_map_entries!(Int64Builder, Int32, i64, map, i) - } - DataType::Int64 => { - aggr_array_from_map_entries!(Int64Builder, Int64, i64, map, i) - } - DataType::Float32 => { - aggr_array_from_map_entries!(Float32Builder, Float32, f32, map, i) - } - DataType::Float64 => { - aggr_array_from_map_entries!(Float64Builder, Float64, f64, map, i) - } - _ => Err(ExecutionError::ExecutionError( - "Unsupported aggregate expr".to_string(), - )), - }; - result_arrays.push(array.map_err(ExecutionError::into_arrow_external_error)?); - } - - let batch = RecordBatch::try_new(self.schema.clone(), result_arrays)?; Ok(Some(batch)) } } @@ -636,36 +466,16 @@ impl RecordBatchReader for HashAggregateIterator { .get_value() .map_err(ExecutionError::into_arrow_external_error)?; let array = match aggr_data_type { - DataType::UInt8 => { - aggr_array_from_accumulator!(UInt64Builder, UInt8, u64, value) - } - DataType::UInt16 => { - aggr_array_from_accumulator!(UInt64Builder, UInt16, u64, value) - } - DataType::UInt32 => { - aggr_array_from_accumulator!(UInt64Builder, UInt32, u64, value) - } - DataType::UInt64 => { - aggr_array_from_accumulator!(UInt64Builder, UInt64, u64, value) - } - DataType::Int8 => { - aggr_array_from_accumulator!(Int64Builder, Int8, i64, value) - } - DataType::Int16 => { - aggr_array_from_accumulator!(Int64Builder, Int16, i64, value) - } - DataType::Int32 => { - aggr_array_from_accumulator!(Int64Builder, Int32, i64, value) - } - DataType::Int64 => { - aggr_array_from_accumulator!(Int64Builder, Int64, i64, value) - } - DataType::Float32 => { - aggr_array_from_accumulator!(Float32Builder, Float32, f32, value) - } - DataType::Float64 => { - aggr_array_from_accumulator!(Float64Builder, Float64, f64, value) - } + DataType::UInt8 => accum_val!(UInt8Builder, UInt8, value), + DataType::UInt16 => accum_val!(UInt16Builder, UInt16, value), + DataType::UInt32 => accum_val!(UInt32Builder, UInt32, value), + DataType::UInt64 => accum_val!(UInt64Builder, UInt64, value), + DataType::Int8 => accum_val!(Int8Builder, Int8, value), + DataType::Int16 => accum_val!(Int16Builder, Int16, value), + DataType::Int32 => accum_val!(Int32Builder, Int32, value), + DataType::Int64 => accum_val!(Int64Builder, Int64, value), + DataType::Float32 => accum_val!(Float32Builder, Float32, value), + DataType::Float64 => accum_val!(Float64Builder, Float64, value), _ => Err(ExecutionError::ExecutionError( "Unsupported aggregate expr".to_string(), )), @@ -678,6 +488,141 @@ impl RecordBatchReader for HashAggregateIterator { } } +/// Append a grouping expression value to a builder +macro_rules! group_val { + ($BUILDER:expr, $BUILDER_TY:ident, $VALUE:expr) => {{ + let builder = $BUILDER + .downcast_mut::<$BUILDER_TY>() + .expect("failed to downcast group value builder to expected type"); + builder.append_value($VALUE)?; + }}; +} + +/// Append an aggregate expression value to a builder +macro_rules! aggr_val { + ($BUILDER:expr, $BUILDER_TY:ident, $VALUE:expr, $SCALAR_TY:ident) => {{ + let builder = $BUILDER + .downcast_mut::<$BUILDER_TY>() + .expect("failed to downcast aggregate value builder to expected type"); + match $VALUE { + Some(ScalarValue::$SCALAR_TY(n)) => builder.append_value(n)?, + None => builder.append_null()?, + Some(other) => { + return Err(ExecutionError::General(format!( + "Unexpected data type {:?} for aggregate value", + other + ))) + } + } + }}; +} + +/// Create a RecordBatch representing the accumulated results in a map +fn create_batch_from_map( + map: &FnvHashMap, Rc>, + num_group_expr: usize, + num_aggr_expr: usize, + output_schema: &Schema, +) -> Result { + // create builders based on the output schema data types + let output_types: Vec<&DataType> = output_schema + .fields() + .iter() + .map(|f| f.data_type()) + .collect(); + let mut builders: Vec> = vec![]; + for data_type in &output_types { + let builder: Box = match data_type { + DataType::Int8 => Box::new(Int8Builder::new(map.len())), + DataType::Int16 => Box::new(Int16Builder::new(map.len())), + DataType::Int32 => Box::new(Int32Builder::new(map.len())), + DataType::Int64 => Box::new(Int64Builder::new(map.len())), + DataType::UInt8 => Box::new(UInt8Builder::new(map.len())), + DataType::UInt16 => Box::new(UInt16Builder::new(map.len())), + DataType::UInt32 => Box::new(UInt32Builder::new(map.len())), + DataType::UInt64 => Box::new(UInt64Builder::new(map.len())), + DataType::Float32 => Box::new(Float32Builder::new(map.len())), + DataType::Float64 => Box::new(Float64Builder::new(map.len())), + DataType::Utf8 => Box::new(StringBuilder::new(map.len())), + _ => { + return Err(ExecutionError::ExecutionError( + "Unsupported data type in final aggregate result".to_string(), + )) + } + }; + builders.push(builder); + } + + // iterate over the map + for (k, v) in map.iter() { + // add group values to builders + for i in 0..num_group_expr { + let builder = builders[i].as_any_mut(); + match &k[i] { + GroupByScalar::Int8(n) => group_val!(builder, Int8Builder, *n), + GroupByScalar::Int16(n) => group_val!(builder, Int16Builder, *n), + GroupByScalar::Int32(n) => group_val!(builder, Int32Builder, *n), + GroupByScalar::Int64(n) => group_val!(builder, Int64Builder, *n), + GroupByScalar::UInt8(n) => group_val!(builder, UInt8Builder, *n), + GroupByScalar::UInt16(n) => group_val!(builder, UInt16Builder, *n), + GroupByScalar::UInt32(n) => group_val!(builder, UInt32Builder, *n), + GroupByScalar::UInt64(n) => group_val!(builder, UInt64Builder, *n), + GroupByScalar::Utf8(str) => group_val!(builder, StringBuilder, str), + } + } + + // add aggregate values to builders + for i in 0..num_aggr_expr { + let value = v[i].borrow().get_value()?; + let index = num_group_expr + i; + let builder = builders[index].as_any_mut(); + match output_types[index] { + DataType::Int8 => aggr_val!(builder, Int8Builder, value, Int8), + DataType::Int16 => aggr_val!(builder, Int16Builder, value, Int16), + DataType::Int32 => aggr_val!(builder, Int32Builder, value, Int32), + DataType::Int64 => aggr_val!(builder, Int64Builder, value, Int64), + DataType::UInt8 => aggr_val!(builder, UInt8Builder, value, UInt8), + DataType::UInt16 => aggr_val!(builder, UInt16Builder, value, UInt16), + DataType::UInt32 => aggr_val!(builder, UInt32Builder, value, UInt32), + DataType::UInt64 => aggr_val!(builder, UInt64Builder, value, UInt64), + DataType::Float32 => aggr_val!(builder, Float32Builder, value, Float32), + DataType::Float64 => aggr_val!(builder, Float64Builder, value, Float64), + // The aggr_val! macro doesn't work for ScalarValue::Utf8 because it contains + // String and the builder wants &str. In all other cases the scalar and builder + // types are the same. + DataType::Utf8 => { + let builder = builder + .downcast_mut::() + .expect("failed to downcast builder to expected type"); + match value { + Some(ScalarValue::Utf8(str)) => builder.append_value(&str)?, + None => builder.append_null()?, + Some(_) => { + return Err(ExecutionError::ExecutionError( + "Invalid value for accumulator".to_string(), + )) + } + } + } + _ => { + return Err(ExecutionError::ExecutionError( + "Unsupported aggregate data type".to_string(), + )) + } + }; + } + } + + let arrays: Vec = builders + .iter_mut() + .map(|builder| builder.finish()) + .collect(); + + let batch = RecordBatch::try_new(Arc::new(output_schema.to_owned()), arrays)?; + + Ok(batch) +} + /// Enumeration of types that can be used in a GROUP BY expression (all primitives except /// for floating point numerics) #[derive(Debug, PartialEq, Eq, Hash, Clone)] diff --git a/rust/datafusion/src/execution/physical_plan/merge.rs b/rust/datafusion/src/execution/physical_plan/merge.rs index 4db54c07187..dfb4f578c82 100644 --- a/rust/datafusion/src/execution/physical_plan/merge.rs +++ b/rust/datafusion/src/execution/physical_plan/merge.rs @@ -18,7 +18,7 @@ //! Defines the merge plan for executing partitions in parallel and then merging the results //! into a single partition -use crate::error::Result; +use crate::error::{ExecutionError, Result}; use crate::execution::physical_plan::common::RecordBatchIterator; use crate::execution::physical_plan::Partition; use crate::execution::physical_plan::{common, ExecutionPlan}; @@ -83,11 +83,14 @@ impl Partition for MergePartition { // combine the results from each thread let mut combined_results: Vec> = vec![]; for thread in threads { - let join = thread.join().expect("Failed to join thread"); - let result = join?; - result - .iter() - .for_each(|batch| combined_results.push(Arc::new(batch.clone()))); + match thread.join() { + Ok(join) => { + join? + .iter() + .for_each(|batch| combined_results.push(Arc::new(batch.clone()))); + } + Err(e) => return Err(ExecutionError::General(format!("{:?}", e))), + } } Ok(Arc::new(Mutex::new(RecordBatchIterator::new( diff --git a/testing b/testing index 535369d600a..f552c4dcd2a 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 535369d600a58cbfe6d952777187561b4dacfcbd +Subproject commit f552c4dcd2ae3d14048abd20919748cce5276ade