diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index 803870f3f7840..77da95c3a04a3 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -40,6 +40,7 @@ mod literal;
 mod min_max;
 mod negative;
 mod not;
+mod nth_value;
 mod nullif;
 mod row_number;
 mod sum;
@@ -58,6 +59,7 @@ pub use literal::{lit, Literal};
 pub use min_max::{Max, Min};
 pub use negative::{negative, NegativeExpr};
 pub use not::{not, NotExpr};
+pub use nth_value::{FirstValue, LastValue, NthValue};
 pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
 pub use row_number::RowNumber;
 pub use sum::{sum_return_type, Sum};
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
new file mode 100644
index 0000000000000..e90ad322aae9d
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{
+    window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
+};
+use crate::scalar::ScalarValue;
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::convert::TryFrom;
+use std::sync::Arc;
+
+/// first_value expression
+#[derive(Debug)]
+pub struct FirstValue {
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl FirstValue {
+    /// Create a new FIRST_VALUE window aggregate function
+    pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
+        Self {
+            name,
+            data_type,
+            expr,
+        }
+    }
+}
+
+impl BuiltInWindowFunctionExpr for FirstValue {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+        Ok(Box::new(NthValueAccumulator::try_new(
+            1,
+            self.data_type.clone(),
+        )?))
+    }
+}
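+
+// Note that FIRST_VALUE is simply NTH_VALUE with n == 1, and LAST_VALUE below
+// reuses the same accumulator with a sentinel n of 0, so all three built-ins
+// share the scan logic in NthValueAccumulator.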
+
+// sql positions start with 1, so we can use 0 to indicate the special
+// last-value behavior
+const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0;
+
+/// last_value expression
+#[derive(Debug)]
+pub struct LastValue {
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl LastValue {
+    /// Create a new LAST_VALUE window aggregate function
+    pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
+        Self {
+            name,
+            data_type,
+            expr,
+        }
+    }
+}
+
+impl BuiltInWindowFunctionExpr for LastValue {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+        Ok(Box::new(NthValueAccumulator::try_new(
+            SPECIAL_SIZE_VALUE_FOR_LAST,
+            self.data_type.clone(),
+        )?))
+    }
+}
+
+/// nth_value expression
+#[derive(Debug)]
+pub struct NthValue {
+    name: String,
+    n: u32,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl NthValue {
+    /// Create a new NTH_VALUE window aggregate function
+    pub fn try_new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: String,
+        n: u32,
+        data_type: DataType,
+    ) -> Result<Self> {
+        if n == SPECIAL_SIZE_VALUE_FOR_LAST {
+            Err(DataFusionError::Execution(
+                "nth_value expects n to be > 0".to_owned(),
+            ))
+        } else {
+            Ok(Self {
+                name,
+                n,
+                data_type,
+                expr,
+            })
+        }
+    }
+}
+
+impl BuiltInWindowFunctionExpr for NthValue {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+        Ok(Box::new(NthValueAccumulator::try_new(
+            self.n,
+            self.data_type.clone(),
+        )?))
+    }
+}
+
+#[derive(Debug)]
+struct NthValueAccumulator {
+    // The target position for nth_value; the accumulator is reused for
+    // last_value, where n == 0 (SPECIAL_SIZE_VALUE_FOR_LAST) specifically
+    // means "last". It is also valid for n to be larger than the number of
+    // input rows, in which case all the values are null.
+    n: u32,
+    offset: u32,
+    value: ScalarValue,
+}
+
+impl NthValueAccumulator {
+    /// Create a new NthValueAccumulator
+    pub fn try_new(n: u32, data_type: DataType) -> Result<Self> {
+        Ok(Self {
+            n,
+            offset: 0,
+            // the null value of that data_type by default
+            value: ScalarValue::try_from(&data_type)?,
+        })
+    }
+}
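+
+// `scan` below is called once per input row with the evaluated argument
+// values, and `evaluate` then yields a single value that is repeated for
+// every row of the partition. For example, scanning [10, 20, 30] with n == 2
+// stores 20 on the second call and ignores later rows; with n == 0
+// (last_value) the value is overwritten on every call and ends at 30; with
+// n == 5 the target row never arrives and the value stays at its initial null.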
+
+impl WindowAccumulator for NthValueAccumulator {
+    fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
+        if self.n == SPECIAL_SIZE_VALUE_FOR_LAST {
+            // for the last_value function, keep overwriting so the final row wins
+            self.value = values[0].clone();
+        } else if self.offset < self.n {
+            // for nth_value, count rows until the nth one arrives, then keep it
+            self.offset += 1;
+            if self.offset == self.n {
+                self.value = values[0].clone();
+            }
+        }
+        Ok(None)
+    }
+
+    fn evaluate(&self) -> Result<Option<ScalarValue>> {
+        Ok(Some(self.value.clone()))
+    }
+}
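+
+// A minimal usage sketch (hypothetical; this module ships no tests of its own):
+//
+//     let mut acc = NthValueAccumulator::try_new(2, DataType::Int32)?;
+//     acc.scan(&[ScalarValue::Int32(Some(10))])?; // offset 1 < n, value still null
+//     acc.scan(&[ScalarValue::Int32(Some(20))])?; // offset 2 == n, store 20
+//     acc.scan(&[ScalarValue::Int32(Some(30))])?; // n already reached, ignored
+//     assert_eq!(acc.evaluate()?, Some(ScalarValue::Int32(Some(20))));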
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 8ced3aec8ec11..e790eeaca749e 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -20,7 +20,9 @@
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
     aggregates,
-    expressions::RowNumber,
+    expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber},
+    type_coercion::coerce,
+    window_functions::signature_for_built_in,
     window_functions::BuiltInWindowFunctionExpr,
     window_functions::{BuiltInWindowFunction, WindowFunction},
     Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
@@ -39,6 +41,7 @@ use futures::stream::{Stream, StreamExt};
 use futures::Future;
 use pin_project_lite::pin_project;
 use std::any::Any;
+use std::convert::TryInto;
 use std::iter;
 use std::pin::Pin;
 use std::sync::Arc;
@@ -82,12 +85,40 @@ pub fn create_window_expr(
 
 fn create_built_in_window_expr(
     fun: &BuiltInWindowFunction,
-    _args: &[Arc<dyn PhysicalExpr>],
-    _input_schema: &Schema,
+    args: &[Arc<dyn PhysicalExpr>],
+    input_schema: &Schema,
     name: String,
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
     match fun {
         BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
+        BuiltInWindowFunction::NthValue => {
+            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
+            let arg = coerced_args[0].clone();
+            // the second argument is expected to be a scalar literal here;
+            // anything else panics on the unwrap below
+            let n = coerced_args[1]
+                .as_any()
+                .downcast_ref::<Literal>()
+                .unwrap()
+                .value();
+            let n: i64 = n
+                .clone()
+                .try_into()
+                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+            let n: u32 = n as u32;
+            let data_type = args[0].data_type(input_schema)?;
+            Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?))
+        }
+        BuiltInWindowFunction::FirstValue => {
+            let arg =
+                coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Ok(Arc::new(FirstValue::new(arg, name, data_type)))
+        }
+        BuiltInWindowFunction::LastValue => {
+            let arg =
+                coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Ok(Arc::new(LastValue::new(arg, name, data_type)))
+        }
         _ => Err(DataFusionError::NotImplemented(format!(
             "Window function with {:?} not yet implemented",
             fun
@@ -484,45 +515,106 @@ impl RecordBatchStream for WindowAggStream {
 
 #[cfg(test)]
 mod tests {
-    // use super::*;
-
-    // /// some mock data to test windows
-    // fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
-    //     // define a schema.
-    //     let schema = Arc::new(Schema::new(vec![
-    //         Field::new("a", DataType::UInt32, false),
-    //         Field::new("b", DataType::Float64, false),
-    //     ]));
-
-    //     // define data.
-    //     (
-    //         schema.clone(),
-    //         vec![
-    //             RecordBatch::try_new(
-    //                 schema.clone(),
-    //                 vec![
-    //                     Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
-    //                     Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
-    //                 ],
-    //             )
-    //             .unwrap(),
-    //             RecordBatch::try_new(
-    //                 schema,
-    //                 vec![
-    //                     Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
-    //                     Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
-    //                 ],
-    //             )
-    //             .unwrap(),
-    //         ],
-    //     )
-    // }
-
-    // #[tokio::test]
-    // async fn window_function() -> Result<()> {
-    //     let input: Arc<dyn ExecutionPlan> = unimplemented!();
-    //     let input_schema = input.schema();
-    //     let window_expr = vec![];
-    //     WindowAggExec::try_new(window_expr, input, input_schema);
-    // }
+    use super::*;
+    use crate::physical_plan::aggregates::AggregateFunction;
+    use crate::physical_plan::collect;
+    use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
+    use crate::physical_plan::expressions::col;
+    use crate::test;
+    use arrow::array::*;
+
+    fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, SchemaRef)> {
+        let schema = test::aggr_test_schema();
+        let path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;
+        let csv = CsvExec::try_new(
+            &path,
+            CsvReadOptions::new().schema(&schema),
+            None,
+            1024,
+            None,
+        )?;
+
+        let input = Arc::new(csv);
+        Ok((input, schema))
+    }
+
+    #[tokio::test]
+    async fn window_function_input_partition() -> Result<()> {
+        let (input, schema) = create_test_schema(4)?;
+
+        let window_exec = Arc::new(WindowAggExec::try_new(
+            vec![create_window_expr(
+                &WindowFunction::AggregateFunction(AggregateFunction::Count),
+                &[col("c3")],
+                schema.as_ref(),
+                "count".to_owned(),
+            )?],
+            input,
+            schema.clone(),
+        )?);
+
+        let result = collect(window_exec).await;
+
+        assert!(result.is_err());
+        if let Some(DataFusionError::Internal(msg)) = result.err() {
+            assert_eq!(
+                msg,
+                "WindowAggExec requires a single input partition".to_owned()
+            );
+        } else {
+            unreachable!("Expect an internal error to happen");
+        }
+        Ok(())
+    }
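+
+    // With an empty OVER clause every row's frame is the whole partition, so
+    // each aggregate below evaluates to the same value on all 100 rows.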
+
+    #[tokio::test]
+    async fn window_function() -> Result<()> {
+        let (input, schema) = create_test_schema(1)?;
+
+        let window_exec = Arc::new(WindowAggExec::try_new(
+            vec![
+                create_window_expr(
+                    &WindowFunction::AggregateFunction(AggregateFunction::Count),
+                    &[col("c3")],
+                    schema.as_ref(),
+                    "count".to_owned(),
+                )?,
+                create_window_expr(
+                    &WindowFunction::AggregateFunction(AggregateFunction::Max),
+                    &[col("c3")],
+                    schema.as_ref(),
+                    "max".to_owned(),
+                )?,
+                create_window_expr(
+                    &WindowFunction::AggregateFunction(AggregateFunction::Min),
+                    &[col("c3")],
+                    schema.as_ref(),
+                    "min".to_owned(),
+                )?,
+            ],
+            input,
+            schema.clone(),
+        )?);
+
+        let result: Vec<RecordBatch> = collect(window_exec).await?;
+        assert_eq!(result.len(), 1);
+
+        let columns = result[0].columns();
+
+        // c3 is a small int (Int8) column
+
+        let count: &UInt64Array = as_primitive_array(&columns[0]);
+        assert_eq!(count.value(0), 100);
+        assert_eq!(count.value(99), 100);
+
+        let max: &Int8Array = as_primitive_array(&columns[1]);
+        assert_eq!(max.value(0), 125);
+        assert_eq!(max.value(99), 125);
+
+        let min: &Int8Array = as_primitive_array(&columns[2]);
+        assert_eq!(min.value(0), -117);
+        assert_eq!(min.value(99), -117);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 55bc88eedf9ab..f5b416f789736 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -807,17 +807,20 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
           avg(c3) over (), \
           count(c3) over (), \
           max(c3) over (), \
-          min(c3) over () \
+          min(c3) over (), \
+          first_value(c3) over (), \
+          last_value(c3) over (), \
+          nth_value(c3, 2) over () \
           from aggregate_test_100 \
           order by c2 \
           limit 5";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
     ];
     assert_eq!(expected, actual);
    Ok(())