From 581fcfd0b2ae98aa99f5cea76384d2a85e23d76b Mon Sep 17 00:00:00 2001 From: Fedomn Date: Wed, 21 Sep 2022 22:22:47 +0800 Subject: [PATCH 1/2] feat(distinct): add DistinctSumAccumulator to support distinct sum Signed-off-by: Fedomn --- Cargo.lock | 10 ++++++ Cargo.toml | 1 + src/binder/expression/agg_func.rs | 12 ++++++- src/binder/expression/mod.rs | 1 + src/executor/aggregate/hash_agg.rs | 1 + src/executor/aggregate/mod.rs | 15 +++++---- src/executor/aggregate/sum.rs | 38 ++++++++++++++++++++++ src/optimizer/input_ref_rewriter.rs | 1 + src/optimizer/rules/column_pruning.rs | 5 ++- src/types/mod.rs | 47 ++++++++++++++++++++++++++- tests/slt/distinct.slt | 31 ++++++++++++++++++ tests/slt/select.slt | 14 -------- 12 files changed, 153 insertions(+), 23 deletions(-) create mode 100644 tests/slt/distinct.slt diff --git a/Cargo.lock b/Cargo.lock index 921e01f..6ed4e81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -890,6 +890,15 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7709cef83f0c1f58f666e746a08b21e0085f7440fa6a29cc194d68aac97a4225" +[[package]] +name = "ordered-float" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98ffdb14730ed2ef599c65810c15b000896e21e8776b512de0db0c3d7335cc2a" +dependencies = [ + "num-traits", +] + [[package]] name = "os_str_bytes" version = "6.2.0" @@ -1256,6 +1265,7 @@ dependencies = [ "futures-async-stream", "itertools", "lazy_static", + "ordered-float", "paste", "petgraph", "pretty_assertions", diff --git a/Cargo.toml b/Cargo.toml index 547234f..b007477 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ enum_dispatch = "0.3" lazy_static = "1" strum = "0.24" strum_macros = "0.24" +ordered-float = "3.0" [dev-dependencies] test-case = "2" diff --git a/src/binder/expression/agg_func.rs b/src/binder/expression/agg_func.rs index d328194..ed70783 100644 --- a/src/binder/expression/agg_func.rs +++ b/src/binder/expression/agg_func.rs @@ -30,6 +30,7 @@ pub struct BoundAggFunc { pub func: AggFunc, pub exprs: Vec, pub return_type: DataType, + pub distinct: bool, } impl Binder { @@ -55,21 +56,25 @@ impl Binder { func: AggFunc::Count, exprs: args.clone(), return_type: DataType::Int64, + distinct: func.distinct, }, "sum" => BoundAggFunc { func: AggFunc::Sum, exprs: args.clone(), return_type: args[0].return_type().unwrap(), + distinct: func.distinct, }, "min" => BoundAggFunc { func: AggFunc::Min, exprs: args.clone(), return_type: args[0].return_type().unwrap(), + distinct: func.distinct, }, "max" => BoundAggFunc { func: AggFunc::Max, exprs: args.clone(), return_type: args[0].return_type().unwrap(), + distinct: func.distinct, }, _ => unimplemented!("not implmented agg func {}", func.name), }; @@ -84,6 +89,11 @@ impl fmt::Debug for BoundAggFunc { } else { format!("{:?}", self.exprs) }; - write!(f, "{}({}):{}", self.func, expr, self.return_type) + let distinct = if self.distinct { "Distinct" } else { "" }; + write!( + f, + "{}{}({}):{}", + distinct, self.func, expr, self.return_type + ) } } diff --git a/src/binder/expression/mod.rs b/src/binder/expression/mod.rs index 7f34240..ce790df 100644 --- a/src/binder/expression/mod.rs +++ b/src/binder/expression/mod.rs @@ -102,6 +102,7 @@ impl Binder { Expr::UnaryOp { op: _, expr: _ } => todo!(), Expr::Value(v) => Ok(BoundExpr::Constant(v.into())), Expr::Function(func) => self.bind_agg_func(func), + Expr::Nested(expr) => self.bind_expr(expr), _ => todo!("unsupported expr {:?}", expr), } } diff --git a/src/executor/aggregate/hash_agg.rs b/src/executor/aggregate/hash_agg.rs index be28204..30e1fe9 100644 --- a/src/executor/aggregate/hash_agg.rs +++ b/src/executor/aggregate/hash_agg.rs @@ -195,6 +195,7 @@ mod tests { func: AggFunc::Sum, exprs: vec![build_bound_input_ref(1)], return_type: DataType::Int64, + distinct: false, })]; let group_by = vec![build_bound_input_ref(0)]; diff --git a/src/executor/aggregate/mod.rs b/src/executor/aggregate/mod.rs index 69f4fdd..bc0c2f9 100644 --- a/src/executor/aggregate/mod.rs +++ b/src/executor/aggregate/mod.rs @@ -2,7 +2,7 @@ use arrow::array::ArrayRef; use self::count::CountAccumulator; use self::min_max::{MaxAccumulator, MinAccumulator}; -use self::sum::SumAccumulator; +use self::sum::{DistinctSumAccumulator, SumAccumulator}; use super::ExecutorError; use crate::binder::{AggFunc, BoundExpr}; use crate::types::ScalarValue; @@ -26,11 +26,14 @@ pub trait Accumulator: Send + Sync { fn create_accumulator(expr: &BoundExpr) -> Box { if let BoundExpr::AggFunc(agg_expr) = expr { - match agg_expr.func { - AggFunc::Count => Box::new(CountAccumulator::new()), - AggFunc::Sum => Box::new(SumAccumulator::new(agg_expr.return_type.clone())), - AggFunc::Min => Box::new(MinAccumulator::new(agg_expr.return_type.clone())), - AggFunc::Max => Box::new(MaxAccumulator::new(agg_expr.return_type.clone())), + match (&agg_expr.func, &agg_expr.distinct) { + (AggFunc::Count, _) => Box::new(CountAccumulator::new()), + (AggFunc::Sum, false) => Box::new(SumAccumulator::new(agg_expr.return_type.clone())), + (AggFunc::Sum, true) => { + Box::new(DistinctSumAccumulator::new(agg_expr.return_type.clone())) + } + (AggFunc::Min, _) => Box::new(MinAccumulator::new(agg_expr.return_type.clone())), + (AggFunc::Max, _) => Box::new(MaxAccumulator::new(agg_expr.return_type.clone())), } } else { unreachable!( diff --git a/src/executor/aggregate/sum.rs b/src/executor/aggregate/sum.rs index 047bda0..0873836 100644 --- a/src/executor/aggregate/sum.rs +++ b/src/executor/aggregate/sum.rs @@ -1,5 +1,8 @@ // most of ideas inspired by datafusion +use std::collections::HashSet; + +use ahash::RandomState; use arrow::array::{ArrayRef, Float64Array, Int32Array, Int64Array}; use arrow::compute; use arrow::compute::kernels::cast::cast; @@ -92,3 +95,38 @@ impl Accumulator for SumAccumulator { Ok(self.result.clone()) } } + +pub struct DistinctSumAccumulator { + distinct_values: HashSet, + data_type: DataType, +} + +impl DistinctSumAccumulator { + pub fn new(data_type: DataType) -> Self { + Self { + distinct_values: HashSet::default(), + data_type, + } + } +} + +impl Accumulator for DistinctSumAccumulator { + fn update_batch(&mut self, array: &ArrayRef) -> Result<(), ExecutorError> { + if array.is_empty() { + return Ok(()); + } + (0..array.len()).for_each(|i| { + let v = ScalarValue::try_from_array(array, i); + self.distinct_values.insert(v); + }); + Ok(()) + } + + fn evaluate(&self) -> Result { + let mut sum = ScalarValue::from(&self.data_type); + for v in self.distinct_values.iter() { + sum = sum_result(&sum, v); + } + Ok(sum) + } +} diff --git a/src/optimizer/input_ref_rewriter.rs b/src/optimizer/input_ref_rewriter.rs index c444c5a..dbb9284 100644 --- a/src/optimizer/input_ref_rewriter.rs +++ b/src/optimizer/input_ref_rewriter.rs @@ -331,6 +331,7 @@ mod input_ref_rewriter_test { func: AggFunc::Sum, exprs: vec![build_bound_column_ref("t", "c1")], return_type: DataType::Int32, + distinct: false, }); let simple_agg = LogicalAgg::new(vec![expr.clone()], vec![], input); LogicalProject::new(vec![expr], Arc::new(simple_agg)) diff --git a/src/optimizer/rules/column_pruning.rs b/src/optimizer/rules/column_pruning.rs index a239bc5..f38ae51 100644 --- a/src/optimizer/rules/column_pruning.rs +++ b/src/optimizer/rules/column_pruning.rs @@ -202,7 +202,10 @@ impl Rule for RemoveNoopOperators { let child_plan_ref = &child_opt_expr.get_plan_ref(); let child_exprs = match child_plan_ref.node_type() { PlanNodeType::LogicalProject => child_plan_ref.as_logical_project().unwrap().exprs(), - PlanNodeType::LogicalAgg => child_plan_ref.as_logical_agg().unwrap().agg_funcs(), + PlanNodeType::LogicalAgg => { + let plan = child_plan_ref.as_logical_agg().unwrap(); + [plan.group_by(), plan.agg_funcs()].concat() + } _other => { unreachable!("RemoveNoopOperators not supprt type: {:?}", _other); } diff --git a/src/types/mod.rs b/src/types/mod.rs index f171201..942e781 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use arrow::array::*; use arrow::datatypes::DataType; use arrow::error::ArrowError; +use ordered_float::OrderedFloat; macro_rules! typed_cast { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ @@ -18,7 +19,7 @@ macro_rules! typed_cast { /// To keep simplicity, we only support some scalar value /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part of arrow’s `Array`. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub enum ScalarValue { /// represents `DataType::Null` (castable to/from any other type) Null, @@ -166,6 +167,50 @@ impl From<&sqlparser::ast::Value> for ScalarValue { } } +impl PartialEq for ScalarValue { + fn eq(&self, other: &Self) -> bool { + use ScalarValue::*; + + match (self, other) { + (Null, Null) => true, + (Null, _) => false, + (Boolean(v1), Boolean(v2)) => v1.eq(v2), + (Boolean(_), _) => false, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float64(_), _) => false, + (Int32(v1), Int32(v2)) => v1.eq(v2), + (Int32(_), _) => false, + (Int64(v1), Int64(v2)) => v1.eq(v2), + (Int64(_), _) => false, + (String(v1), String(v2)) => v1.eq(v2), + (String(_), _) => false, + } + } +} + +impl Eq for ScalarValue {} + +impl std::hash::Hash for ScalarValue { + fn hash(&self, state: &mut H) { + match self { + // stable hash for Null value + ScalarValue::Null => 1.hash(state), + ScalarValue::Boolean(v) => v.hash(state), + ScalarValue::Float64(v) => { + // f64 not implement Hash, see https://internals.rust-lang.org/t/f32-f64-should-implement-hash/5436/3 + v.map(OrderedFloat).hash(state); + } + ScalarValue::Int32(v) => v.hash(state), + ScalarValue::Int64(v) => v.hash(state), + ScalarValue::String(v) => v.hash(state), + } + } +} + pub fn build_scalar_value_array(scalar_value: &ScalarValue, capacity: usize) -> ArrayRef { match scalar_value { ScalarValue::Null => new_null_array(&DataType::Null, capacity), diff --git a/tests/slt/distinct.slt b/tests/slt/distinct.slt new file mode 100644 index 0000000..66f696f --- /dev/null +++ b/tests/slt/distinct.slt @@ -0,0 +1,31 @@ +query I +select distinct state from employee; +---- +CA +CO +(empty) + +query II +select distinct a, b from t2; +---- +10 2 +20 2 +30 3 +40 4 + +query I +select sum(distinct b) from t2; +---- +9 + +query I +select sum(distinct(b)) from t2; +---- +9 + +query I +select sum(distinct(b)) from t2 group by c; +---- +2 +2 +7 diff --git a/tests/slt/select.slt b/tests/slt/select.slt index 3d28ead..d975a4d 100644 --- a/tests/slt/select.slt +++ b/tests/slt/select.slt @@ -6,17 +6,3 @@ Gregg CO 2 10000 John CO 3 11500 Von (empty) 4 NULL -query I -select distinct state from employee ----- -CA -CO -(empty) - -query II -select distinct a, b from t2 ----- -10 2 -20 2 -30 3 -40 4 From ccdab9334866dabe3fc600032139bae3a0fa24a8 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Thu, 22 Sep 2022 21:22:02 +0800 Subject: [PATCH 2/2] feat(distinct): add DistinctCountAccumulator to support distinct count Signed-off-by: Fedomn --- src/executor/aggregate/count.rs | 32 ++++++++++++++++++++++++++++++++ src/executor/aggregate/mod.rs | 5 +++-- tests/slt/distinct.slt | 5 +++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/executor/aggregate/count.rs b/src/executor/aggregate/count.rs index cd06f61..ac9e216 100644 --- a/src/executor/aggregate/count.rs +++ b/src/executor/aggregate/count.rs @@ -1,3 +1,6 @@ +use std::collections::HashSet; + +use ahash::RandomState; use arrow::array::ArrayRef; use super::Accumulator; @@ -24,3 +27,32 @@ impl Accumulator for CountAccumulator { Ok(ScalarValue::Int64(Some(self.result))) } } + +pub struct DistinctCountAccumulator { + distinct_values: HashSet, +} + +impl DistinctCountAccumulator { + pub fn new() -> Self { + Self { + distinct_values: HashSet::default(), + } + } +} + +impl Accumulator for DistinctCountAccumulator { + fn update_batch(&mut self, array: &ArrayRef) -> Result<(), ExecutorError> { + if array.is_empty() { + return Ok(()); + } + (0..array.len()).for_each(|i| { + let v = ScalarValue::try_from_array(array, i); + self.distinct_values.insert(v); + }); + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.distinct_values.len() as i64))) + } +} diff --git a/src/executor/aggregate/mod.rs b/src/executor/aggregate/mod.rs index bc0c2f9..b43940a 100644 --- a/src/executor/aggregate/mod.rs +++ b/src/executor/aggregate/mod.rs @@ -1,6 +1,6 @@ use arrow::array::ArrayRef; -use self::count::CountAccumulator; +use self::count::{CountAccumulator, DistinctCountAccumulator}; use self::min_max::{MaxAccumulator, MinAccumulator}; use self::sum::{DistinctSumAccumulator, SumAccumulator}; use super::ExecutorError; @@ -27,7 +27,8 @@ pub trait Accumulator: Send + Sync { fn create_accumulator(expr: &BoundExpr) -> Box { if let BoundExpr::AggFunc(agg_expr) = expr { match (&agg_expr.func, &agg_expr.distinct) { - (AggFunc::Count, _) => Box::new(CountAccumulator::new()), + (AggFunc::Count, false) => Box::new(CountAccumulator::new()), + (AggFunc::Count, true) => Box::new(DistinctCountAccumulator::new()), (AggFunc::Sum, false) => Box::new(SumAccumulator::new(agg_expr.return_type.clone())), (AggFunc::Sum, true) => { Box::new(DistinctSumAccumulator::new(agg_expr.return_type.clone())) diff --git a/tests/slt/distinct.slt b/tests/slt/distinct.slt index 66f696f..6364114 100644 --- a/tests/slt/distinct.slt +++ b/tests/slt/distinct.slt @@ -29,3 +29,8 @@ select sum(distinct(b)) from t2 group by c; 2 2 7 + +query I +select count(distinct(b)) from t2; +---- +3