diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index dee253f44ac33..b53f7c15e3aac 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1766,6 +1766,25 @@ mod tests { "+-----+-------------+", ]; assert_batches_sorted_eq!(expected, &results); + + // Now, use dict as an aggregate + let results = plan_and_collect( + &mut ctx, + "SELECT val, count(distinct dict) FROM t GROUP BY val", + ) + .await + .expect("ran plan correctly"); + + let expected = vec![ + "+-----+----------------------+", + "| val | COUNT(DISTINCT dict) |", + "+-----+----------------------+", + "| 1 | 2 |", + "| 2 | 2 |", + "| 4 | 1 |", + "+-----+----------------------+", + ]; + assert_batches_sorted_eq!(expected, &results); } run_test_case::().await; diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/distinct_expressions.rs index 1c93b5a104d09..8167541c3e1a5 100644 --- a/datafusion/src/physical_plan/distinct_expressions.rs +++ b/datafusion/src/physical_plan/distinct_expressions.rs @@ -47,8 +47,8 @@ pub struct DistinctCount { name: String, /// The DataType for the final count data_type: DataType, - /// The DataType for each input argument - input_data_types: Vec, + /// The DataType used to hold the state for each input + state_data_types: Vec, /// The input arguments exprs: Vec>, } @@ -61,8 +61,10 @@ impl DistinctCount { name: String, data_type: DataType, ) -> Self { + let state_data_types = input_data_types.into_iter().map(state_type).collect(); + Self { - input_data_types, + state_data_types, exprs, name, data_type, @@ -70,6 +72,15 @@ impl DistinctCount { } } +/// return the type to use to accumulate state for the specified input type +fn state_type(data_type: DataType) -> DataType { + match data_type { + // when aggregating dictionary values, use the underlying value type + DataType::Dictionary(_key_type, value_type) => *value_type, + t => t, + } +} + impl AggregateExpr for DistinctCount { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -82,12 +93,16 @@ impl AggregateExpr for DistinctCount { fn state_fields(&self) -> Result> { Ok(self - .input_data_types + .state_data_types .iter() - .map(|data_type| { + .map(|state_data_type| { Field::new( &format_state_name(&self.name, "count distinct"), - DataType::List(Box::new(Field::new("item", data_type.clone(), true))), + DataType::List(Box::new(Field::new( + "item", + state_data_type.clone(), + true, + ))), false, ) }) @@ -101,7 +116,7 @@ impl AggregateExpr for DistinctCount { fn create_accumulator(&self) -> Result> { Ok(Box::new(DistinctCountAccumulator { values: HashSet::default(), - data_types: self.input_data_types.clone(), + state_data_types: self.state_data_types.clone(), count_data_type: self.data_type.clone(), })) } @@ -110,7 +125,7 @@ impl AggregateExpr for DistinctCount { #[derive(Debug)] struct DistinctCountAccumulator { values: HashSet, - data_types: Vec, + state_data_types: Vec, count_data_type: DataType, } @@ -156,9 +171,11 @@ impl Accumulator for DistinctCountAccumulator { fn state(&self) -> Result> { let mut cols_out = self - .data_types + .state_data_types .iter() - .map(|data_type| ScalarValue::List(Some(Vec::new()), data_type.clone())) + .map(|state_data_type| { + ScalarValue::List(Some(Vec::new()), state_data_type.clone()) + }) .collect::>(); let mut cols_vec = cols_out diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 11f0946c91ff6..a8f6f0c35f00e 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -274,6 +274,7 @@ pub trait AggregateExpr: Send + Sync + Debug { /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; + /// the field of the final result of this aggregation. fn field(&self) -> Result; diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 6f03194f45423..eb9b3095cd626 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -19,10 +19,13 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; -use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{ArrowDictionaryKeyType, DataType, Field, IntervalUnit, TimeUnit}; use arrow::{ array::*, - datatypes::{ArrowNativeType, Float32Type, TimestampNanosecondType}, + datatypes::{ + ArrowNativeType, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, + TimestampNanosecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, }; use arrow::{ array::{ @@ -444,14 +447,53 @@ impl ScalarValue { DataType::Timestamp(TimeUnit::Nanosecond, _) => { typed_cast!(array, index, TimestampNanosecondArray, TimestampNanosecond) } + DataType::Dictionary(index_type, _) => match **index_type { + DataType::Int8 => Self::try_from_dict_array::(array, index)?, + DataType::Int16 => Self::try_from_dict_array::(array, index)?, + DataType::Int32 => Self::try_from_dict_array::(array, index)?, + DataType::Int64 => Self::try_from_dict_array::(array, index)?, + DataType::UInt8 => Self::try_from_dict_array::(array, index)?, + DataType::UInt16 => { + Self::try_from_dict_array::(array, index)? + } + DataType::UInt32 => { + Self::try_from_dict_array::(array, index)? + } + DataType::UInt64 => { + Self::try_from_dict_array::(array, index)? + } + _ => { + return Err(DataFusionError::Internal(format!( + "Index type not supported while creating scalar from dictionary: {}", + array.data_type(), + ))) + } + }, other => { return Err(DataFusionError::NotImplemented(format!( - "Can't create a scalar of array of type \"{:?}\"", + "Can't create a scalar from array of type \"{:?}\"", other ))) } }) } + + fn try_from_dict_array( + array: &ArrayRef, + index: usize, + ) -> Result { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + + // look up the index in the values dictionary + let keys_col = dict_array.keys_array(); + let values_index = keys_col.value(index).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + })?; + Self::try_from_array(&dict_array.values(), values_index) + } } impl From for ScalarValue {