From 74f5042c81b76043e47abe0c1a3e026823089941 Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Sat, 11 Mar 2023 16:43:59 -0500 Subject: [PATCH 1/7] Add new count distinct accumulator when running on dict arrays --- .../src/aggregate/count_distinct.rs | 296 +++++++++++++++++- 1 file changed, 286 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 078e38ceb09d2..cc2f8407966f1 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{ArrowDictionaryKeyType, DataType, Field}; +use datafusion_common::cast::as_dictionary_array; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; @@ -85,10 +86,47 @@ impl AggregateExpr for DistinctCount { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - })) + use arrow::datatypes; + use datatypes::DataType::*; + + Ok(match &self.state_data_type { + Dictionary(key, _) if key.is_dictionary_key_type() => { + match **key { + Int8 => Box::new( + CountDistinctDictAccumulator::::new(), + ), + Int16 => Box::new( + CountDistinctDictAccumulator::::new(), + ), + Int32 => Box::new( + CountDistinctDictAccumulator::::new(), + ), + Int64 => Box::new( + CountDistinctDictAccumulator::::new(), + ), + UInt8 => Box::new( + CountDistinctDictAccumulator::::new(), + ), + UInt16 => Box::new(CountDistinctDictAccumulator::< + datatypes::UInt16Type, + >::new()), + UInt32 => Box::new(CountDistinctDictAccumulator::< + datatypes::UInt32Type, + >::new()), + UInt64 => Box::new(CountDistinctDictAccumulator::< + datatypes::UInt64Type, + >::new()), + _ => { + // just checked that datatype is a valid dict key type + unreachable!() + } + } + } + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: self.state_data_type.clone(), + }), + }) } fn name(&self) -> &str { @@ -192,6 +230,143 @@ impl Accumulator for DistinctCountAccumulator { } } } +/// Special case accumulator for counting distinct values in a dict +struct CountDistinctDictAccumulator +where + K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync, +{ + /// `K` is required when casting to dict array + _dt: core::marker::PhantomData, + /// laziliy initialized state that holds a boolean for each index. + /// the bool at each index indicates whether the value for that index has been seen yet. + state: Option>, +} + +impl std::fmt::Debug for CountDistinctDictAccumulator +where + K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CountDistinctDictAccumulator") + .field("state", &self.state) + .finish() + } +} +impl + CountDistinctDictAccumulator +{ + fn new() -> Self { + Self { + _dt: core::marker::PhantomData, + state: None, + } + } +} +impl Accumulator for CountDistinctDictAccumulator +where + K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync, +{ + fn state(&self) -> Result> { + if let Some(state) = &self.state { + let bools = state + .iter() + .map(|b| ScalarValue::Boolean(Some(*b))) + .collect(); + Ok(vec![ScalarValue::List( + Some(bools), + Box::new(Field::new("item", DataType::Boolean, false)), + )]) + } else { + // empty state + Ok(vec![]) + } + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = as_dictionary_array::(&values[0])?; + let nvalues = arr.values().len(); + if let Some(state) = &self.state { + if state.len() != nvalues { + return Err(DataFusionError::Internal( + "Accumulator update_batch got invalid value".to_string(), + )); + } + } else { + // init state + self.state = Some((0..nvalues).map(|_| false).collect()); + } + for key in arr.keys_iter() { + if let Some(idx) = key { + self.state + .as_mut() + // state always will have been initialized at this point + .unwrap()[idx] = true; + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let scalar = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::List(Some(scalar), _) = scalar { + if self.state.is_none() { + self.state = Some((0..scalar.len()).map(|_| false).collect()); + } else if scalar.len() != self.state.as_ref().unwrap().len() { + return Err(DataFusionError::Internal( + "accumulator merged invalid state".into(), + )); + } + for (idx, val) in scalar.iter().enumerate() { + match val { + ScalarValue::Boolean(Some(b)) => { + if *b { + self.state.as_mut().unwrap()[idx] = true; + } + } + _ => { + return Err(DataFusionError::Internal( + "Unexpected accumulator state".into(), + )); + } + } + } + } else { + return Err(DataFusionError::Internal( + "Unexpected accumulator state".into(), + )); + } + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + if let Some(state) = &self.state { + let num_seen = state.iter().filter(|v| **v).count(); + Ok(ScalarValue::Int64(Some(num_seen as i64))) + } else { + Ok(ScalarValue::Int64(Some(0))) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self + .state + .as_ref() + .map(|state| std::mem::size_of::() * state.capacity()) + .unwrap_or(0) + } +} #[cfg(test)] mod tests { @@ -199,10 +374,11 @@ mod tests { use super::*; use arrow::array::{ - ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BooleanArray, DictionaryArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Int8Type}; macro_rules! state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ @@ -346,8 +522,6 @@ mod tests { let mut state_vec = state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); - - dbg!(&state_vec); state_vec.sort_by(|a, b| match (a, b) { (Some(lhs), Some(rhs)) => lhs.total_cmp(rhs), _ => a.partial_cmp(b).unwrap(), @@ -577,4 +751,106 @@ mod tests { assert_eq!(result, ScalarValue::Int64(Some(2))); Ok(()) } + + #[test] + fn count_distinct_dict_update() -> Result<()> { + let values = StringArray::from_iter_values(["a", "b", "c"]); + // value "b" is never used + let keys = + Int8Array::from_iter(vec![Some(0), Some(0), Some(0), Some(0), None, Some(2)]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + let agg = DistinctCount::new( + arrays[0].data_type().clone(), + Arc::new(NoOp::new()), + String::from("__col_name__"), + ); + let mut accum = agg.create_accumulator()?; + accum.update_batch(&arrays)?; + // should evaluate to 2 since "b" never seen + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(2))); + // now update with a new batch that does use "b" + let values = StringArray::from_iter_values(["a", "b", "c"]); + let keys = Int8Array::from_iter(vec![Some(1), Some(1), None]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + accum.update_batch(&arrays)?; + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn count_distinct_dict_merge() -> Result<()> { + let values = StringArray::from_iter_values(["a", "b", "c"]); + let keys = Int8Array::from_iter(vec![Some(0), Some(0), None]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + let agg = DistinctCount::new( + arrays[0].data_type().clone(), + Arc::new(NoOp::new()), + String::from("__col_name__"), + ); + // create accum with 1 value seen + let mut accum = agg.create_accumulator()?; + accum.update_batch(&arrays)?; + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(1))); + // create accum with state that has seen "a" and "b" but not "c" + let values = StringArray::from_iter_values(["a", "b", "c"]); + let keys = Int8Array::from_iter(vec![Some(0), Some(1), None]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + let mut accum2 = agg.create_accumulator()?; + accum2.update_batch(&arrays)?; + let states = accum2 + .state()? + .into_iter() + .map(|v| v.to_array()) + .collect::>(); + // after merging the accumulator should have seen 2 vals + accum.merge_batch(&states)?; + assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(2))); + Ok(()) + } + + #[test] + fn count_distinct_dict_merge_inits_state() -> Result<()> { + let values = StringArray::from_iter_values(["a", "b", "c"]); + let keys = Int8Array::from_iter(vec![Some(0), Some(1), None]); + let arrays = + vec![ + Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) + as ArrayRef, + ]; + let agg = DistinctCount::new( + arrays[0].data_type().clone(), + Arc::new(NoOp::new()), + String::from("__col_name__"), + ); + // create accum to get a state from + let mut accum = agg.create_accumulator()?; + accum.update_batch(&arrays)?; + let states = accum + .state()? + .into_iter() + .map(|v| v.to_array()) + .collect::>(); + // create accum that hasnt been initialized + // the merge_batch should initialize its state + let mut accum2 = agg.create_accumulator()?; + accum2.merge_batch(&states)?; + assert_eq!(accum2.evaluate()?, ScalarValue::Int64(Some(2))); + Ok(()) + } } From e5f42c38829e77b9bb3e1ff215b5a7e1ae76617c Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Sat, 11 Mar 2023 17:24:49 -0500 Subject: [PATCH 2/7] clippy --- datafusion/physical-expr/src/aggregate/count_distinct.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index cc2f8407966f1..7ea1ae8e751ae 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -298,13 +298,8 @@ where // init state self.state = Some((0..nvalues).map(|_| false).collect()); } - for key in arr.keys_iter() { - if let Some(idx) = key { - self.state - .as_mut() - // state always will have been initialized at this point - .unwrap()[idx] = true; - } + for idx in arr.keys_iter().flatten() { + self.state.as_mut().unwrap()[idx] = true; } Ok(()) } From dc479b6a2d84708fa9c4aee2d35dae27a08e6e5c Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Sun, 12 Mar 2023 13:41:40 -0400 Subject: [PATCH 3/7] dont assume normalized dicts. move shared logic into fns --- .../src/aggregate/count_distinct.rs | 295 +++++++----------- 1 file changed, 108 insertions(+), 187 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 7ea1ae8e751ae..a4632a9614eea 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -90,32 +90,43 @@ impl AggregateExpr for DistinctCount { use datatypes::DataType::*; Ok(match &self.state_data_type { - Dictionary(key, _) if key.is_dictionary_key_type() => { + Dictionary(key, val) if key.is_dictionary_key_type() => { + let val_type = *val.clone(); match **key { Int8 => Box::new( - CountDistinctDictAccumulator::::new(), + CountDistinctDictAccumulator::::new( + val_type, + ), ), Int16 => Box::new( - CountDistinctDictAccumulator::::new(), + CountDistinctDictAccumulator::::new( + val_type, + ), ), Int32 => Box::new( - CountDistinctDictAccumulator::::new(), + CountDistinctDictAccumulator::::new( + val_type, + ), ), Int64 => Box::new( - CountDistinctDictAccumulator::::new(), + CountDistinctDictAccumulator::::new( + val_type, + ), ), UInt8 => Box::new( - CountDistinctDictAccumulator::::new(), + CountDistinctDictAccumulator::::new( + val_type, + ), ), UInt16 => Box::new(CountDistinctDictAccumulator::< datatypes::UInt16Type, - >::new()), + >::new(val_type)), UInt32 => Box::new(CountDistinctDictAccumulator::< datatypes::UInt32Type, - >::new()), + >::new(val_type)), UInt64 => Box::new(CountDistinctDictAccumulator::< datatypes::UInt64Type, - >::new()), + >::new(val_type)), _ => { // just checked that datatype is a valid dict key type unreachable!() @@ -133,6 +144,56 @@ impl AggregateExpr for DistinctCount { &self.name } } +type ValueSet = HashSet; +// calculating the size of values hashset for fixed length values, +// taking first batch size * number of batches. +// This method is faster than full_size(), however it is not suitable for variable length +// values like strings or complex types +fn values_fixed_size(values: &ValueSet) -> usize { + (std::mem::size_of::() * values.capacity()) + + values + .iter() + .next() + .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .unwrap_or(0) +} +// calculates the size as accurate as possible, call to this method is expensive +fn values_full_size(values: &ValueSet) -> usize { + (std::mem::size_of::() * values.capacity()) + + values + .iter() + .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .sum::() +} + +// helper func that takes accumulator state and merges it into a ValuesSet +fn merge_values(values: &mut ValueSet, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let arr = &states[0]; + for index in 0..arr.len() { + let scalar = ScalarValue::try_from_array(arr, index)?; + if let ScalarValue::List(Some(scalar), _) = scalar { + for val in scalar.iter() { + if !val.is_null() { + values.insert(val.clone()); + } + } + } else { + return Err(DataFusionError::Internal( + "Unexpected accumulator state".into(), + )); + } + } + Ok(()) +} + +// helper that converts value hashset into state vector +fn values_to_state(values: &ValueSet, datatype: &DataType) -> Result> { + let scalars = values.iter().cloned().collect::>(); + Ok(vec![ScalarValue::new_list(Some(scalars), datatype.clone())]) +} #[derive(Debug)] struct DistinctCountAccumulator { @@ -140,47 +201,9 @@ struct DistinctCountAccumulator { state_data_type: DataType, } -impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * number of batches - // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types - fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .unwrap_or(0) - + std::mem::size_of::() - } - - // calculates the size as accurate as possible, call to this method is expensive - fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .sum::() - + std::mem::size_of::() - } -} - impl Accumulator for DistinctCountAccumulator { fn state(&self) -> Result> { - let mut cols_out = - ScalarValue::new_list(Some(Vec::new()), self.state_data_type.clone()); - self.values - .iter() - .enumerate() - .for_each(|(_, distinct_values)| { - if let ScalarValue::List(Some(ref mut v), _) = cols_out { - v.push(distinct_values.clone()); - } - }); - Ok(vec![cols_out]) + values_to_state(&self.values, &self.state_data_type) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { @@ -196,26 +219,7 @@ impl Accumulator for DistinctCountAccumulator { }) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - scalar.iter().for_each(|scalar| { - if !ScalarValue::is_null(scalar) { - self.values.insert(scalar.clone()); - } - }); - } else { - return Err(DataFusionError::Internal( - "Unexpected accumulator state".into(), - )); - } - Ok(()) - }) + merge_values(&mut self.values, states) } fn evaluate(&self) -> Result { @@ -223,11 +227,12 @@ impl Accumulator for DistinctCountAccumulator { } fn size(&self) -> usize { - match &self.state_data_type { - DataType::Boolean | DataType::Null => self.fixed_size(), - d if d.is_primitive() => self.fixed_size(), - _ => self.full_size(), - } + let values_size = match &self.state_data_type { + DataType::Boolean | DataType::Null => values_fixed_size(&self.values), + d if d.is_primitive() => values_fixed_size(&self.values), + _ => values_full_size(&self.values), + }; + std::mem::size_of_val(self) + values_size + std::mem::size_of::() } } /// Special case accumulator for counting distinct values in a dict @@ -237,9 +242,8 @@ where { /// `K` is required when casting to dict array _dt: core::marker::PhantomData, - /// laziliy initialized state that holds a boolean for each index. - /// the bool at each index indicates whether the value for that index has been seen yet. - state: Option>, + values_datatype: DataType, + values: HashSet, } impl std::fmt::Debug for CountDistinctDictAccumulator @@ -248,17 +252,19 @@ where { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("CountDistinctDictAccumulator") - .field("state", &self.state) + .field("values", &self.values) + .field("values_datatype", &self.values_datatype) .finish() } } impl CountDistinctDictAccumulator { - fn new() -> Self { + fn new(values_datatype: DataType) -> Self { Self { _dt: core::marker::PhantomData, - state: None, + values: Default::default(), + values_datatype, } } } @@ -267,19 +273,7 @@ where K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync, { fn state(&self) -> Result> { - if let Some(state) = &self.state { - let bools = state - .iter() - .map(|b| ScalarValue::Boolean(Some(*b))) - .collect(); - Ok(vec![ScalarValue::List( - Some(bools), - Box::new(Field::new("item", DataType::Boolean, false)), - )]) - } else { - // empty state - Ok(vec![]) - } + values_to_state(&self.values, &self.values_datatype) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -288,78 +282,35 @@ where } let arr = as_dictionary_array::(&values[0])?; let nvalues = arr.values().len(); - if let Some(state) = &self.state { - if state.len() != nvalues { - return Err(DataFusionError::Internal( - "Accumulator update_batch got invalid value".to_string(), - )); - } - } else { - // init state - self.state = Some((0..nvalues).map(|_| false).collect()); - } + // map keys to whether their corresponding value has been seen or not + let mut seen_map = (0..nvalues).map(|_| false).collect::>(); for idx in arr.keys_iter().flatten() { - self.state.as_mut().unwrap()[idx] = true; + seen_map[idx] = true; + } + for (idx, seen) in seen_map.into_iter().enumerate() { + if seen { + let scalar = ScalarValue::try_from_array(arr.values(), idx)?; + self.values.insert(scalar); + } } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - if self.state.is_none() { - self.state = Some((0..scalar.len()).map(|_| false).collect()); - } else if scalar.len() != self.state.as_ref().unwrap().len() { - return Err(DataFusionError::Internal( - "accumulator merged invalid state".into(), - )); - } - for (idx, val) in scalar.iter().enumerate() { - match val { - ScalarValue::Boolean(Some(b)) => { - if *b { - self.state.as_mut().unwrap()[idx] = true; - } - } - _ => { - return Err(DataFusionError::Internal( - "Unexpected accumulator state".into(), - )); - } - } - } - } else { - return Err(DataFusionError::Internal( - "Unexpected accumulator state".into(), - )); - } - Ok(()) - }) + merge_values(&mut self.values, states) } fn evaluate(&self) -> Result { - if let Some(state) = &self.state { - let num_seen = state.iter().filter(|v| **v).count(); - Ok(ScalarValue::Int64(Some(num_seen as i64))) - } else { - Ok(ScalarValue::Int64(Some(0))) - } + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self - .state - .as_ref() - .map(|state| std::mem::size_of::() * state.capacity()) - .unwrap_or(0) + let values_size = match &self.values_datatype { + DataType::Boolean | DataType::Null => values_fixed_size(&self.values), + d if d.is_primitive() => values_fixed_size(&self.values), + _ => values_full_size(&self.values), + }; + std::mem::size_of_val(self) + values_size + std::mem::size_of::() } } @@ -767,9 +718,9 @@ mod tests { accum.update_batch(&arrays)?; // should evaluate to 2 since "b" never seen assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(2))); - // now update with a new batch that does use "b" - let values = StringArray::from_iter_values(["a", "b", "c"]); - let keys = Int8Array::from_iter(vec![Some(1), Some(1), None]); + // now update with a new batch that does use "b" (and non-normalized values) + let values = StringArray::from_iter_values(["b", "a", "c", "d"]); + let keys = Int8Array::from_iter(vec![Some(0), Some(0), None]); let arrays = vec![ Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) @@ -799,8 +750,8 @@ mod tests { accum.update_batch(&arrays)?; assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(1))); // create accum with state that has seen "a" and "b" but not "c" - let values = StringArray::from_iter_values(["a", "b", "c"]); - let keys = Int8Array::from_iter(vec![Some(0), Some(1), None]); + let values = StringArray::from_iter_values(["c", "b", "a"]); + let keys = Int8Array::from_iter(vec![Some(2), Some(1), None]); let arrays = vec![ Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) @@ -818,34 +769,4 @@ mod tests { assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(2))); Ok(()) } - - #[test] - fn count_distinct_dict_merge_inits_state() -> Result<()> { - let values = StringArray::from_iter_values(["a", "b", "c"]); - let keys = Int8Array::from_iter(vec![Some(0), Some(1), None]); - let arrays = - vec![ - Arc::new(DictionaryArray::::try_new(&keys, &values).unwrap()) - as ArrayRef, - ]; - let agg = DistinctCount::new( - arrays[0].data_type().clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - // create accum to get a state from - let mut accum = agg.create_accumulator()?; - accum.update_batch(&arrays)?; - let states = accum - .state()? - .into_iter() - .map(|v| v.to_array()) - .collect::>(); - // create accum that hasnt been initialized - // the merge_batch should initialize its state - let mut accum2 = agg.create_accumulator()?; - accum2.merge_batch(&states)?; - assert_eq!(accum2.evaluate()?, ScalarValue::Int64(Some(2))); - Ok(()) - } } From 93ea2dbc8c474eb9b8f3f1b8b01c520ed31d6db3 Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Sun, 12 Mar 2023 13:48:57 -0400 Subject: [PATCH 4/7] organize type alias --- datafusion/physical-expr/src/aggregate/count_distinct.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index a4632a9614eea..d7596148cb02c 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -32,7 +32,7 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; type DistinctScalarValues = ScalarValue; - +type ValueSet = HashSet; /// Expression for a COUNT(DISTINCT) aggregation. #[derive(Debug)] pub struct DistinctCount { @@ -144,7 +144,7 @@ impl AggregateExpr for DistinctCount { &self.name } } -type ValueSet = HashSet; + // calculating the size of values hashset for fixed length values, // taking first batch size * number of batches. // This method is faster than full_size(), however it is not suitable for variable length @@ -197,7 +197,7 @@ fn values_to_state(values: &ValueSet, datatype: &DataType) -> Result, + values: ValueSet, state_data_type: DataType, } @@ -243,7 +243,7 @@ where /// `K` is required when casting to dict array _dt: core::marker::PhantomData, values_datatype: DataType, - values: HashSet, + values: ValueSet, } impl std::fmt::Debug for CountDistinctDictAccumulator From 3a0a01cb3b0f31e4af72f739640bcced2257d3c3 Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Tue, 14 Mar 2023 11:42:14 -0400 Subject: [PATCH 5/7] Update datafusion/physical-expr/src/aggregate/count_distinct.rs Co-authored-by: Andrew Lamb --- datafusion/physical-expr/src/aggregate/count_distinct.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index d7596148cb02c..30515bea7d8c1 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -158,6 +158,7 @@ fn values_fixed_size(values: &ValueSet) -> usize { .unwrap_or(0) } // calculates the size as accurate as possible, call to this method is expensive +// but necessary to correctly account for variable length strings fn values_full_size(values: &ValueSet) -> usize { (std::mem::size_of::() * values.capacity()) + values From 3824bd9fba6e935b7cef72eeb5fbae366fb6f65d Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Tue, 14 Mar 2023 11:50:55 -0400 Subject: [PATCH 6/7] Update datafusion/physical-expr/src/aggregate/count_distinct.rs Co-authored-by: Andrew Lamb --- datafusion/physical-expr/src/aggregate/count_distinct.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 30515bea7d8c1..1db9b49bf184f 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -284,7 +284,7 @@ where let arr = as_dictionary_array::(&values[0])?; let nvalues = arr.values().len(); // map keys to whether their corresponding value has been seen or not - let mut seen_map = (0..nvalues).map(|_| false).collect::>(); + let mut seen_map = vec![(false; nvalues]; for idx in arr.keys_iter().flatten() { seen_map[idx] = true; } From a881c5ac4e8b7ddd55b3d6d8e9fd4385285ee1ef Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Tue, 14 Mar 2023 12:06:11 -0400 Subject: [PATCH 7/7] suggested changes --- .../src/aggregate/count_distinct.rs | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 1db9b49bf184f..93d8d154d611a 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -31,8 +31,7 @@ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; -type DistinctScalarValues = ScalarValue; -type ValueSet = HashSet; +type ValueSet = HashSet; /// Expression for a COUNT(DISTINCT) aggregation. #[derive(Debug)] pub struct DistinctCount { @@ -128,8 +127,9 @@ impl AggregateExpr for DistinctCount { datatypes::UInt64Type, >::new(val_type)), _ => { - // just checked that datatype is a valid dict key type - unreachable!() + return Err(DataFusionError::Internal( + "Dict key has invalid datatype".to_string(), + )) } } } @@ -150,7 +150,7 @@ impl AggregateExpr for DistinctCount { // This method is faster than full_size(), however it is not suitable for variable length // values like strings or complex types fn values_fixed_size(values: &ValueSet) -> usize { - (std::mem::size_of::() * values.capacity()) + (std::mem::size_of::() * values.capacity()) + values .iter() .next() @@ -160,7 +160,7 @@ fn values_fixed_size(values: &ValueSet) -> usize { // calculates the size as accurate as possible, call to this method is expensive // but necessary to correctly account for variable length strings fn values_full_size(values: &ValueSet) -> usize { - (std::mem::size_of::() * values.capacity()) + (std::mem::size_of::() * values.capacity()) + values .iter() .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) @@ -284,14 +284,12 @@ where let arr = as_dictionary_array::(&values[0])?; let nvalues = arr.values().len(); // map keys to whether their corresponding value has been seen or not - let mut seen_map = vec![(false; nvalues]; + let mut seen_map = vec![false; nvalues]; for idx in arr.keys_iter().flatten() { - seen_map[idx] = true; - } - for (idx, seen) in seen_map.into_iter().enumerate() { - if seen { + if !seen_map[idx] { let scalar = ScalarValue::try_from_array(arr.values(), idx)?; self.values.insert(scalar); + seen_map[idx] = true; } } Ok(())