Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ enum_dispatch = "0.3"
lazy_static = "1"
strum = "0.24"
strum_macros = "0.24"
ordered-float = "3.0"

[dev-dependencies]
test-case = "2"
Expand Down
12 changes: 11 additions & 1 deletion src/binder/expression/agg_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct BoundAggFunc {
pub func: AggFunc,
pub exprs: Vec<BoundExpr>,
pub return_type: DataType,
pub distinct: bool,
}

impl Binder {
Expand All @@ -55,21 +56,25 @@ impl Binder {
func: AggFunc::Count,
exprs: args.clone(),
return_type: DataType::Int64,
distinct: func.distinct,
},
"sum" => BoundAggFunc {
func: AggFunc::Sum,
exprs: args.clone(),
return_type: args[0].return_type().unwrap(),
distinct: func.distinct,
},
"min" => BoundAggFunc {
func: AggFunc::Min,
exprs: args.clone(),
return_type: args[0].return_type().unwrap(),
distinct: func.distinct,
},
"max" => BoundAggFunc {
func: AggFunc::Max,
exprs: args.clone(),
return_type: args[0].return_type().unwrap(),
distinct: func.distinct,
},
_ => unimplemented!("not implmented agg func {}", func.name),
};
Expand All @@ -84,6 +89,11 @@ impl fmt::Debug for BoundAggFunc {
} else {
format!("{:?}", self.exprs)
};
write!(f, "{}({}):{}", self.func, expr, self.return_type)
let distinct = if self.distinct { "Distinct" } else { "" };
write!(
f,
"{}{}({}):{}",
distinct, self.func, expr, self.return_type
)
}
}
1 change: 1 addition & 0 deletions src/binder/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ impl Binder {
Expr::UnaryOp { op: _, expr: _ } => todo!(),
Expr::Value(v) => Ok(BoundExpr::Constant(v.into())),
Expr::Function(func) => self.bind_agg_func(func),
Expr::Nested(expr) => self.bind_expr(expr),
_ => todo!("unsupported expr {:?}", expr),
}
}
Expand Down
32 changes: 32 additions & 0 deletions src/executor/aggregate/count.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::collections::HashSet;

use ahash::RandomState;
use arrow::array::ArrayRef;

use super::Accumulator;
Expand All @@ -24,3 +27,32 @@ impl Accumulator for CountAccumulator {
Ok(ScalarValue::Int64(Some(self.result)))
}
}

/// Accumulator used for `COUNT` when the aggregate call is marked `DISTINCT`.
///
/// Tracks the distinct values observed in a hash set; the number of entries in that set is the
/// aggregate result.
pub struct DistinctCountAccumulator {
    /// Unique values collected so far.
    distinct_values: HashSet<ScalarValue, RandomState>,
}

impl DistinctCountAccumulator {
    /// Creates an accumulator that has not seen any value yet.
    pub fn new() -> Self {
        let distinct_values = HashSet::default();
        Self { distinct_values }
    }
}

impl Accumulator for DistinctCountAccumulator {
    /// Records every non-NULL value of `array` in the distinct set.
    ///
    /// NULL slots are skipped on purpose: `COUNT(DISTINCT expr)` only counts non-NULL values, so
    /// letting a NULL into the set would inflate the result by one for any input that contains
    /// a NULL. An empty `array` is a no-op.
    fn update_batch(&mut self, array: &ArrayRef) -> Result<(), ExecutorError> {
        let non_null_values = (0..array.len())
            .filter(|&row| array.is_valid(row))
            .map(|row| ScalarValue::try_from_array(array, row));
        self.distinct_values.extend(non_null_values);
        Ok(())
    }

    /// Returns the number of distinct values recorded so far as `ScalarValue::Int64`.
    fn evaluate(&self) -> Result<ScalarValue, ExecutorError> {
        Ok(ScalarValue::Int64(Some(self.distinct_values.len() as i64)))
    }
}
1 change: 1 addition & 0 deletions src/executor/aggregate/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ mod tests {
func: AggFunc::Sum,
exprs: vec![build_bound_input_ref(1)],
return_type: DataType::Int64,
distinct: false,
})];

let group_by = vec![build_bound_input_ref(0)];
Expand Down
18 changes: 11 additions & 7 deletions src/executor/aggregate/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use arrow::array::ArrayRef;

use self::count::CountAccumulator;
use self::count::{CountAccumulator, DistinctCountAccumulator};
use self::min_max::{MaxAccumulator, MinAccumulator};
use self::sum::SumAccumulator;
use self::sum::{DistinctSumAccumulator, SumAccumulator};
use super::ExecutorError;
use crate::binder::{AggFunc, BoundExpr};
use crate::types::ScalarValue;
Expand All @@ -26,11 +26,15 @@ pub trait Accumulator: Send + Sync {

fn create_accumulator(expr: &BoundExpr) -> Box<dyn Accumulator> {
if let BoundExpr::AggFunc(agg_expr) = expr {
match agg_expr.func {
AggFunc::Count => Box::new(CountAccumulator::new()),
AggFunc::Sum => Box::new(SumAccumulator::new(agg_expr.return_type.clone())),
AggFunc::Min => Box::new(MinAccumulator::new(agg_expr.return_type.clone())),
AggFunc::Max => Box::new(MaxAccumulator::new(agg_expr.return_type.clone())),
match (&agg_expr.func, &agg_expr.distinct) {
(AggFunc::Count, false) => Box::new(CountAccumulator::new()),
(AggFunc::Count, true) => Box::new(DistinctCountAccumulator::new()),
(AggFunc::Sum, false) => Box::new(SumAccumulator::new(agg_expr.return_type.clone())),
(AggFunc::Sum, true) => {
Box::new(DistinctSumAccumulator::new(agg_expr.return_type.clone()))
}
(AggFunc::Min, _) => Box::new(MinAccumulator::new(agg_expr.return_type.clone())),
(AggFunc::Max, _) => Box::new(MaxAccumulator::new(agg_expr.return_type.clone())),
}
} else {
unreachable!(
Expand Down
38 changes: 38 additions & 0 deletions src/executor/aggregate/sum.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// most of ideas inspired by datafusion

use std::collections::HashSet;

use ahash::RandomState;
use arrow::array::{ArrayRef, Float64Array, Int32Array, Int64Array};
use arrow::compute;
use arrow::compute::kernels::cast::cast;
Expand Down Expand Up @@ -92,3 +95,38 @@ impl Accumulator for SumAccumulator {
Ok(self.result.clone())
}
}

/// Accumulator used for `SUM` when the aggregate call is marked `DISTINCT`.
///
/// Values are de-duplicated through a hash set while batches arrive; the actual summation is
/// deferred until `evaluate` is called.
pub struct DistinctSumAccumulator {
    /// Unique values collected so far.
    distinct_values: HashSet<ScalarValue, RandomState>,
    /// Type used to build the initial value of the sum.
    data_type: DataType,
}

impl DistinctSumAccumulator {
    /// Creates an accumulator with no recorded values that will sum as `data_type`.
    pub fn new(data_type: DataType) -> Self {
        let distinct_values = HashSet::default();
        Self {
            distinct_values,
            data_type,
        }
    }
}

impl Accumulator for DistinctSumAccumulator {
    /// Converts each row of `array` to a `ScalarValue` and records it in the distinct set.
    ///
    /// Duplicates are dropped by the set; an empty `array` leaves the state untouched.
    fn update_batch(&mut self, array: &ArrayRef) -> Result<(), ExecutorError> {
        for row in 0..array.len() {
            let value = ScalarValue::try_from_array(array, row);
            self.distinct_values.insert(value);
        }
        Ok(())
    }

    /// Folds all recorded values with `sum_result`, starting from the value that
    /// `ScalarValue::from` builds for this accumulator's data type.
    ///
    /// NOTE(review): the set is visited in hash order, which is unspecified; for floating-point
    /// input the rounding of the result may therefore vary between runs — confirm acceptable.
    fn evaluate(&self) -> Result<ScalarValue, ExecutorError> {
        let initial = ScalarValue::from(&self.data_type);
        let total = self
            .distinct_values
            .iter()
            .fold(initial, |acc, value| sum_result(&acc, value));
        Ok(total)
    }
}
1 change: 1 addition & 0 deletions src/optimizer/input_ref_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ mod input_ref_rewriter_test {
func: AggFunc::Sum,
exprs: vec![build_bound_column_ref("t", "c1")],
return_type: DataType::Int32,
distinct: false,
});
let simple_agg = LogicalAgg::new(vec![expr.clone()], vec![], input);
LogicalProject::new(vec![expr], Arc::new(simple_agg))
Expand Down
5 changes: 4 additions & 1 deletion src/optimizer/rules/column_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ impl Rule for RemoveNoopOperators {
let child_plan_ref = &child_opt_expr.get_plan_ref();
let child_exprs = match child_plan_ref.node_type() {
PlanNodeType::LogicalProject => child_plan_ref.as_logical_project().unwrap().exprs(),
PlanNodeType::LogicalAgg => child_plan_ref.as_logical_agg().unwrap().agg_funcs(),
PlanNodeType::LogicalAgg => {
let plan = child_plan_ref.as_logical_agg().unwrap();
[plan.group_by(), plan.agg_funcs()].concat()
}
_other => {
unreachable!("RemoveNoopOperators not supprt type: {:?}", _other);
}
Expand Down
47 changes: 46 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use arrow::array::*;
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
use ordered_float::OrderedFloat;

macro_rules! typed_cast {
($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
Expand All @@ -18,7 +19,7 @@ macro_rules! typed_cast {
/// To keep simplicity, we only support some scalar value
/// Represents a dynamically typed, nullable single value.
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
pub enum ScalarValue {
/// represents `DataType::Null` (castable to/from any other type)
Null,
Expand Down Expand Up @@ -166,6 +167,50 @@ impl From<&sqlparser::ast::Value> for ScalarValue {
}
}

impl PartialEq for ScalarValue {
    /// Two scalars are equal when they are the same variant and wrap equal payloads.
    ///
    /// `Float64` payloads are compared through `OrderedFloat` rather than with the raw `f64`
    /// comparison, matching what the `Hash` impl feeds into the hasher.
    fn eq(&self, other: &Self) -> bool {
        use ScalarValue::*;

        // Each variant has its own "anything else" arm instead of one shared `_` catch-all, so
        // adding a variant to the enum fails to compile here until it is handled.
        match (self, other) {
            (Null, Null) => true,
            (Null, _) => false,
            (Boolean(lhs), Boolean(rhs)) => lhs == rhs,
            (Boolean(_), _) => false,
            (Float64(lhs), Float64(rhs)) => lhs.map(OrderedFloat) == rhs.map(OrderedFloat),
            (Float64(_), _) => false,
            (Int32(lhs), Int32(rhs)) => lhs == rhs,
            (Int32(_), _) => false,
            (Int64(lhs), Int64(rhs)) => lhs == rhs,
            (Int64(_), _) => false,
            (String(lhs), String(rhs)) => lhs == rhs,
            (String(_), _) => false,
        }
    }
}

// Marker impl so `ScalarValue` can be used as a `HashSet`/`HashMap` key. `Eq` cannot be derived
// because of the `f64` payload; the hand-written `PartialEq` compares floats via `OrderedFloat`.
// NOTE(review): soundness relies on `OrderedFloat` equality being reflexive (NaN == NaN) —
// confirm against the ordered-float documentation.
impl Eq for ScalarValue {}

impl std::hash::Hash for ScalarValue {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
// stable hash for Null value
ScalarValue::Null => 1.hash(state),
ScalarValue::Boolean(v) => v.hash(state),
ScalarValue::Float64(v) => {
// f64 not implement Hash, see https://internals.rust-lang.org/t/f32-f64-should-implement-hash/5436/3
v.map(OrderedFloat).hash(state);
}
ScalarValue::Int32(v) => v.hash(state),
ScalarValue::Int64(v) => v.hash(state),
ScalarValue::String(v) => v.hash(state),
}
}
}

pub fn build_scalar_value_array(scalar_value: &ScalarValue, capacity: usize) -> ArrayRef {
match scalar_value {
ScalarValue::Null => new_null_array(&DataType::Null, capacity),
Expand Down
36 changes: 36 additions & 0 deletions tests/slt/distinct.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
query I
select distinct state from employee;
----
CA
CO
(empty)

query II
select distinct a, b from t2;
----
10 2
20 2
30 3
40 4

query I
select sum(distinct b) from t2;
----
9

query I
select sum(distinct(b)) from t2;
----
9

query I
select sum(distinct(b)) from t2 group by c;
----
2
2
7

query I
select count(distinct(b)) from t2;
----
3
14 changes: 0 additions & 14 deletions tests/slt/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,3 @@ Gregg CO 2 10000
John CO 3 11500
Von (empty) 4 NULL

query I
select distinct state from employee
----
CA
CO
(empty)

query II
select distinct a, b from t2
----
10 2
20 2
30 3
40 4