diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 078e38ceb09d2..93d8d154d611a 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -15,7 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::{DataType, Field};
+use arrow::datatypes::{ArrowDictionaryKeyType, DataType, Field};
+use datafusion_common::cast::as_dictionary_array;
 use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
@@ -30,8 +31,7 @@ use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
 
-type DistinctScalarValues = ScalarValue;
-
+type ValueSet = HashSet<ScalarValue>;
 /// Expression for a COUNT(DISTINCT) aggregation.
 #[derive(Debug)]
 pub struct DistinctCount {
@@ -85,10 +85,59 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use arrow::datatypes;
+        use datatypes::DataType::*;
+
+        Ok(match &self.state_data_type {
+            Dictionary(key, val) if key.is_dictionary_key_type() => {
+                let val_type = *val.clone();
+                match **key {
+                    Int8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int16 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int16Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int32 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int32Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int64 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int64Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::UInt8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt16 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt16Type,
+                    >::new(val_type)),
+                    UInt32 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt32Type,
+                    >::new(val_type)),
+                    UInt64 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt64Type,
+                    >::new(val_type)),
+                    _ => {
+                        return Err(DataFusionError::Internal(
+                            "Dict key has invalid datatype".to_string(),
+                        ))
+                    }
+                }
+            }
+            _ => Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: self.state_data_type.clone(),
+            }),
+        })
     }
 
     fn name(&self) -> &str {
@@ -96,53 +145,66 @@ impl AggregateExpr for DistinctCount {
     }
 }
 
-#[derive(Debug)]
-struct DistinctCountAccumulator {
-    values: HashSet<DistinctScalarValues>,
-    state_data_type: DataType,
+// calculates the size of the values hashset for fixed-length values,
+// taking the size of the first entry * the number of entries.
+// This method is faster than values_full_size(), however it is not suitable for
+// variable-length values like strings or complex types
+fn values_fixed_size(values: &ValueSet) -> usize {
+    (std::mem::size_of::<ScalarValue>() * values.capacity())
+        + values
+            .iter()
+            .next()
+            .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+            .unwrap_or(0)
+}
+// calculates the size as accurately as possible; calls to this method are expensive
+// but necessary to correctly account for variable-length values such as strings
+fn values_full_size(values: &ValueSet) -> usize {
+    (std::mem::size_of::<ScalarValue>() * values.capacity())
+        + values
+            .iter()
+            .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+            .sum::<usize>()
 }
 
-impl DistinctCountAccumulator {
-    // calculating the size for fixed length values, taking first batch size * number of batches
-    // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
-    fn fixed_size(&self) -> usize {
-        std::mem::size_of_val(self)
-            + (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
-            + self
-                .values
-                .iter()
-                .next()
-                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
-                .unwrap_or(0)
-            + std::mem::size_of::<DataType>()
-    }
-
-    // calculates the size as accurate as possible, call to this method is expensive
-    fn full_size(&self) -> usize {
-        std::mem::size_of_val(self)
-            + (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
-            + self
-                .values
-                .iter()
-                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
-                .sum::<usize>()
-            + std::mem::size_of::<DataType>()
+// helper func that takes accumulator state and merges it into a ValueSet
+fn merge_values(values: &mut ValueSet, states: &[ArrayRef]) -> Result<()> {
+    if states.is_empty() {
+        return Ok(());
+    }
+    let arr = &states[0];
+    for index in 0..arr.len() {
+        let scalar = ScalarValue::try_from_array(arr, index)?;
+        if let ScalarValue::List(Some(scalar), _) = scalar {
+            for val in scalar.iter() {
+                if !val.is_null() {
+                    values.insert(val.clone());
+                }
+            }
+        } else {
+            return Err(DataFusionError::Internal(
+                "Unexpected accumulator state".into(),
+            ));
+        }
     }
+    Ok(())
+}
+
+// helper that converts the value hashset into a state vector
+fn values_to_state(values: &ValueSet, datatype: &DataType) -> Result<Vec<ScalarValue>> {
+    let scalars = values.iter().cloned().collect::<Vec<_>>();
+    Ok(vec![ScalarValue::new_list(Some(scalars), datatype.clone())])
+}
+
+#[derive(Debug)]
+struct DistinctCountAccumulator {
+    values: ValueSet,
+    state_data_type: DataType,
 }
 
 impl Accumulator for DistinctCountAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        let mut cols_out =
-            ScalarValue::new_list(Some(Vec::new()), self.state_data_type.clone());
-        self.values
-            .iter()
-            .enumerate()
-            .for_each(|(_, distinct_values)| {
-                if let ScalarValue::List(Some(ref mut v), _) = cols_out {
-                    v.push(distinct_values.clone());
-                }
-            });
-        Ok(vec![cols_out])
+        values_to_state(&self.values, &self.state_data_type)
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         if values.is_empty() {
@@ -158,26 +220,83 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
+        merge_values(&mut self.values, states)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let values_size = match &self.state_data_type {
+            DataType::Boolean | DataType::Null => values_fixed_size(&self.values),
+            d if d.is_primitive() => values_fixed_size(&self.values),
+            _ => values_full_size(&self.values),
+        };
+        std::mem::size_of_val(self) + values_size + std::mem::size_of::<DataType>()
+    }
+}
+/// Special case accumulator for counting distinct values in a dict
+struct CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    /// `K` is required when casting to dict array
+    _dt: core::marker::PhantomData<K>,
+    values_datatype: DataType,
+    values: ValueSet,
+}
+
+impl<K> std::fmt::Debug for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CountDistinctDictAccumulator")
+            .field("values", &self.values)
+            .field("values_datatype", &self.values_datatype)
+            .finish()
+    }
+}
+impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>
+    CountDistinctDictAccumulator<K>
+{
+    fn new(values_datatype: DataType) -> Self {
+        Self {
+            _dt: core::marker::PhantomData,
+            values: Default::default(),
+            values_datatype,
         }
-        let arr = &states[0];
-        (0..arr.len()).try_for_each(|index| {
-            let scalar = ScalarValue::try_from_array(arr, index)?;
+    }
+}
+impl<K> Accumulator for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        values_to_state(&self.values, &self.values_datatype)
+    }
 
-            if let ScalarValue::List(Some(scalar), _) = scalar {
-                scalar.iter().for_each(|scalar| {
-                    if !ScalarValue::is_null(scalar) {
-                        self.values.insert(scalar.clone());
-                    }
-                });
-            } else {
-                return Err(DataFusionError::Internal(
-                    "Unexpected accumulator state".into(),
-                ));
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = as_dictionary_array::<K>(&values[0])?;
+        let nvalues = arr.values().len();
+        // map keys to whether their corresponding value has been seen or not
+        let mut seen_map = vec![false; nvalues];
+        for idx in arr.keys_iter().flatten() {
+            if !seen_map[idx] {
+                let scalar = ScalarValue::try_from_array(arr.values(), idx)?;
+                self.values.insert(scalar);
+                seen_map[idx] = true;
             }
-            Ok(())
-        })
+        }
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        merge_values(&mut self.values, states)
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
@@ -185,11 +304,12 @@ impl Accumulator for DistinctCountAccumulator {
     }
 
     fn size(&self) -> usize {
-        match &self.state_data_type {
-            DataType::Boolean | DataType::Null => self.fixed_size(),
-            d if d.is_primitive() => self.fixed_size(),
-            _ => self.full_size(),
-        }
+        let values_size = match &self.values_datatype {
+            DataType::Boolean | DataType::Null => values_fixed_size(&self.values),
+            d if d.is_primitive() => values_fixed_size(&self.values),
+            _ => values_full_size(&self.values),
+        };
+        std::mem::size_of_val(self) + values_size + std::mem::size_of::<DataType>()
     }
 }
 
@@ -199,10 +319,11 @@ mod tests {
     use super::*;
 
     use arrow::array::{
-        ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
-        Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+        ArrayRef, BooleanArray, DictionaryArray, Float32Array, Float64Array, Int16Array,
+        Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array,
+        UInt64Array, UInt8Array,
     };
-    use arrow::datatypes::DataType;
+    use arrow::datatypes::{DataType, Int8Type};
 
     macro_rules! state_to_vec {
state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ @@ -346,8 +467,6 @@ mod tests { let mut state_vec = state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); - - dbg!(&state_vec); state_vec.sort_by(|a, b| match (a, b) { (Some(lhs), Some(rhs)) => lhs.total_cmp(rhs), _ => a.partial_cmp(b).unwrap(), @@ -577,4 +696,76 @@ mod tests { assert_eq!(result, ScalarValue::Int64(Some(2))); Ok(()) } + + #[test] + fn count_distinct_dict_update() -> Result<()> { + let values = StringArray::from_iter_values(["a", "b", "c"]); + // value "b" is never used + let keys = + Int8Array::from_iter(vec![Some(0), Some(0), Some(0), Some(0), None, Some(2)]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + let agg = DistinctCount::new( + arrays[0].data_type().clone(), + Arc::new(NoOp::new()), + String::from("__col_name__"), + ); + let mut accum = agg.create_accumulator()?; + accum.update_batch(&arrays)?; + // should evaluate to 2 since "b" never seen + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(2))); + // now update with a new batch that does use "b" (and non-normalized values) + let values = StringArray::from_iter_values(["b", "a", "c", "d"]); + let keys = Int8Array::from_iter(vec![Some(0), Some(0), None]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + accum.update_batch(&arrays)?; + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn count_distinct_dict_merge() -> Result<()> { + let values = StringArray::from_iter_values(["a", "b", "c"]); + let keys = Int8Array::from_iter(vec![Some(0), Some(0), None]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + let agg = DistinctCount::new( + arrays[0].data_type().clone(), + Arc::new(NoOp::new()), + String::from("__col_name__"), + ); + // create accum with 1 value seen + let mut accum = agg.create_accumulator()?; + accum.update_batch(&arrays)?; + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(1))); + // create accum with state that has seen "a" and "b" but not "c" + let values = StringArray::from_iter_values(["c", "b", "a"]); + let keys = Int8Array::from_iter(vec![Some(2), Some(1), None]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + let mut accum2 = agg.create_accumulator()?; + accum2.update_batch(&arrays)?; + let states = accum2 + .state()? + .into_iter() + .map(|v| v.to_array()) + .collect::>(); + // after merging the accumulator should have seen 2 vals + accum.merge_batch(&states)?; + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(2))); + Ok(()) + } }