From 0c5efacbdd03a14d192e485f9f73717103b9e64f Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 19:44:33 +0800 Subject: [PATCH] Squashed commit of the following: commit 7fb3640e733bfbbdbf18d58000896f378ba9644c Author: Jiayu Liu Date: Fri May 21 16:38:25 2021 +0800 row number done commit 17239267cd2fbcbb676d5731beeffd0321bbd3ba Author: Jiayu Liu Date: Fri May 21 16:05:50 2021 +0800 add row number commit bf5b8a56f6f33d8eedf3e3009e7fcdb3c388ea5b Author: Jiayu Liu Date: Fri May 21 15:04:49 2021 +0800 save commit d2ce852ead5d8ae3d15962b4dd3062e24bce51de Author: Jiayu Liu Date: Fri May 21 14:53:05 2021 +0800 add streams commit 0a861a76bde0bb43e5561f1cf1ef14fd64e0c08b Author: Jiayu Liu Date: Thu May 20 22:28:34 2021 +0800 save stream commit a9121af7e2e9104d0e4b6ca3ef4f484aaf8baf42 Author: Jiayu Liu Date: Thu May 20 22:01:51 2021 +0800 update unit test commit 2af2a270262ff1bc759af39153d7cd681c32dc0a Author: Jiayu Liu Date: Fri May 21 14:25:12 2021 +0800 fix unit test commit bb57c762b0a1fabc35e207e681bca2bfff7fcf01 Author: Jiayu Liu Date: Fri May 21 14:23:34 2021 +0800 use upper case commit 5d96e525f587fbfaf3e5e9762c9bb10315fcbc3a Author: Jiayu Liu Date: Fri May 21 14:16:16 2021 +0800 fix unit test commit 1ecae8f6cbc6c1898ccf0b38b1e596b6c2e9bb46 Author: Jiayu Liu Date: Fri May 21 12:27:26 2021 +0800 fix unit test commit bc2271d58fd4a9a9cc96126f8abcd6e8f10272ca Author: Jiayu Liu Date: Fri May 21 10:04:29 2021 +0800 fix error commit 880b94f6e27df61b4d3877366f71a51b9b2f5d5d Author: Jiayu Liu Date: Fri May 21 08:24:00 2021 +0800 fix unit test commit 4e792e123a33fd0dcb5f701c679566b55589b0c0 Author: Jiayu Liu Date: Fri May 21 08:05:17 2021 +0800 fix test commit c36c04abf06c74d016597983bf3d3a2a5b5cbdd5 Author: Jiayu Liu Date: Fri May 21 00:07:54 2021 +0800 add more tests commit f5e64de7192a1916df78a4c2fbab7d471c906720 Author: Jiayu Liu Date: Thu May 20 23:41:36 2021 +0800 update commit a1eae864926a6acfeeebe995a12de4ad725ea869 Author: Jiayu Liu Date: Thu May 20 23:36:15 2021 +0800 enrich unit test commit 0d2a214131fe69e19e22144c68fbb992228db6b3 Author: Jiayu Liu Date: Thu May 20 23:25:43 2021 +0800 adding filter by todo commit 8b486d53b09ff1c7a6b9cf4687796ba1c13d6160 Author: Jiayu Liu Date: Thu May 20 23:17:22 2021 +0800 adding more built-in functions commit abf08cd137a80c1381af7de9ae2b3dab05cb4512 Author: Jiayu Liu Date: Thu May 20 22:36:27 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb commit 0cbca53dac642233520f7d32289b1dfad77b882e Author: Jiayu Liu Date: Thu May 20 22:34:57 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb commit 831c069f02236a953653b8f1ca25124e393ce20b Author: Jiayu Liu Date: Thu May 20 22:34:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb commit f70c739fd40e30c4b476253e58b24b9297b42859 Author: Jiayu Liu Date: Thu May 20 22:33:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb commit 3ee87aa3477c160f17a86628d71a353e03d736b3 Author: Jiayu Liu Date: Wed May 19 22:55:08 2021 +0800 fix unit test commit 5c4d92dc9f570ba6919d84cb8ac70a736d73f40f Author: Jiayu Liu Date: Wed May 19 22:48:26 2021 +0800 fix clippy commit a0b7526c413abbdd4aadab4af8ca9ad8f323f03b Author: Jiayu Liu Date: Wed May 19 22:46:38 2021 +0800 fix unused imports commit 1d3b076acc1c0f248a19c6149c0634e63a5b836e Author: Jiayu Liu Date: Thu May 13 18:51:14 2021 +0800 add window expr --- .../src/physical_plan/expressions/mod.rs | 2 + 
 .../physical_plan/expressions/nth_value.rs   | 223 ++++++++++++++++++
 datafusion/src/physical_plan/windows.rs      | 180 ++++++++++----
 datafusion/tests/sql.rs                      |  17 +-
 4 files changed, 371 insertions(+), 51 deletions(-)
 create mode 100644 datafusion/src/physical_plan/expressions/nth_value.rs

diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index 803870f3f7840..77da95c3a04a3 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -40,6 +40,7 @@ mod literal;
 mod min_max;
 mod negative;
 mod not;
+mod nth_value;
 mod nullif;
 mod row_number;
 mod sum;
@@ -58,6 +59,7 @@ pub use literal::{lit, Literal};
 pub use min_max::{Max, Min};
 pub use negative::{negative, NegativeExpr};
 pub use not::{not, NotExpr};
+pub use nth_value::{FirstValue, LastValue, NthValue};
 pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
 pub use row_number::RowNumber;
 pub use sum::{sum_return_type, Sum};
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
new file mode 100644
index 0000000000000..e90ad322aae9d
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{
+    window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
+};
+use crate::scalar::ScalarValue;
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::convert::TryFrom;
+use std::sync::Arc;
+
+/// first_value expression
+#[derive(Debug)]
+pub struct FirstValue {
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl FirstValue {
+    /// Create a new FIRST_VALUE window aggregate function
+    pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
+        Self {
+            name,
+            data_type,
+            expr,
+        }
+    }
+}
+
+impl BuiltInWindowFunctionExpr for FirstValue {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+        Ok(Box::new(NthValueAccumulator::try_new(
+            1,
+            self.data_type.clone(),
+        )?))
+    }
+}
+
+// SQL values start with 1, so we can use 0 to indicate the special last value behavior
+const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0;
+
+/// last_value expression
+#[derive(Debug)]
+pub struct LastValue {
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl LastValue {
+    /// Create a new LAST_VALUE window aggregate function
+    pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
+        Self {
+            name,
+            data_type,
+            expr,
+        }
+    }
+}
+
+impl BuiltInWindowFunctionExpr for LastValue {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+        Ok(Box::new(NthValueAccumulator::try_new(
+            SPECIAL_SIZE_VALUE_FOR_LAST,
+            self.data_type.clone(),
+        )?))
+    }
+}
+
+/// nth_value expression
+#[derive(Debug)]
+pub struct NthValue {
+    name: String,
+    n: u32,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl NthValue {
+    /// Create a new NTH_VALUE window aggregate function
+    pub fn try_new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: String,
+        n: u32,
+        data_type: DataType,
+    ) -> Result<Self> {
+        if n == SPECIAL_SIZE_VALUE_FOR_LAST {
+            Err(DataFusionError::Execution(
+                "nth_value expects n to be > 0".to_owned(),
+            ))
+        } else {
+            Ok(Self {
+                name,
+                n,
+                data_type,
+                expr,
+            })
+        }
+    }
+}
+
+impl BuiltInWindowFunctionExpr for NthValue {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+        Ok(Box::new(NthValueAccumulator::try_new(
+            self.n,
+            self.data_type.clone(),
+        )?))
+    }
+}
+
+#[derive(Debug)]
+struct NthValueAccumulator {
+    // n is the target for nth_value; we reuse it for the last_value accumulator, so n == 0 specifically
+    // means last. Also note that it is valid for n to be larger than the number of input rows,
+    // in which case all the values shall be null
+    n: u32,
+    offset: u32,
+    value: ScalarValue,
+}
+
+impl NthValueAccumulator {
+    /// Create a new NTH_VALUE accumulator
+    pub fn try_new(n: u32, data_type: DataType) -> Result<Self> {
+        Ok(Self {
+            n,
+            offset: 0,
+            // null value of that data_type by default
+            value: ScalarValue::try_from(&data_type)?,
+        })
+    }
+}
+
+impl WindowAccumulator for NthValueAccumulator {
+    fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
+        if self.n == SPECIAL_SIZE_VALUE_FOR_LAST {
+            // for last_value function
+            self.value = values[0].clone();
+        } else if self.offset < self.n {
+            self.offset += 1;
+            if self.offset == self.n {
+                self.value = values[0].clone();
+            }
+        }
+        Ok(None)
+    }
+
+    fn evaluate(&self) -> Result<Option<ScalarValue>> {
+        Ok(Some(self.value.clone()))
+    }
+}
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 8ced3aec8ec11..e790eeaca749e 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -20,7 +20,9 @@
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
     aggregates,
-    expressions::RowNumber,
+    expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber},
+    type_coercion::coerce,
+    window_functions::signature_for_built_in,
     window_functions::BuiltInWindowFunctionExpr,
     window_functions::{BuiltInWindowFunction, WindowFunction},
     Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
@@ -39,6 +41,7 @@ use futures::stream::{Stream, StreamExt};
 use futures::Future;
 use pin_project_lite::pin_project;
 use std::any::Any;
+use std::convert::TryInto;
 use std::iter;
 use std::pin::Pin;
 use std::sync::Arc;
@@ -82,12 +85,40 @@ pub fn create_window_expr(
 fn create_built_in_window_expr(
     fun: &BuiltInWindowFunction,
-    _args: &[Arc<dyn PhysicalExpr>],
-    _input_schema: &Schema,
+    args: &[Arc<dyn PhysicalExpr>],
+    input_schema: &Schema,
     name: String,
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
     match fun {
         BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
+        BuiltInWindowFunction::NthValue => {
+            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
+            let arg = coerced_args[0].clone();
+            let n = coerced_args[1]
+                .as_any()
+                .downcast_ref::<Literal>()
+                .unwrap()
+                .value();
+            let n: i64 = n
+                .clone()
+                .try_into()
+                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+            let n: u32 = n as u32;
+            let data_type = args[0].data_type(input_schema)?;
+            Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?))
+        }
+        BuiltInWindowFunction::FirstValue => {
+            let arg =
+                coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Ok(Arc::new(FirstValue::new(arg, name, data_type)))
+        }
+        BuiltInWindowFunction::LastValue => {
+            let arg =
+                coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Ok(Arc::new(LastValue::new(arg, name, data_type)))
+        }
         _ => Err(DataFusionError::NotImplemented(format!(
             "Window function with {:?} not yet implemented",
             fun
@@ -484,45 +515,106 @@ impl RecordBatchStream for WindowAggStream {
 #[cfg(test)]
 mod tests {
-    // use super::*;
-
-    // /// some mock data to test windows
-    // fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
-    //     // define a schema.
-    //     let schema = Arc::new(Schema::new(vec![
-    //         Field::new("a", DataType::UInt32, false),
-    //         Field::new("b", DataType::Float64, false),
-    //     ]));
-
-    //     // define data.
-    //     (
-    //         schema.clone(),
-    //         vec![
-    //             RecordBatch::try_new(
-    //                 schema.clone(),
-    //                 vec![
-    //                     Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
-    //                     Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
-    //                 ],
-    //             )
-    //             .unwrap(),
-    //             RecordBatch::try_new(
-    //                 schema,
-    //                 vec![
-    //                     Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
-    //                     Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
-    //                 ],
-    //             )
-    //             .unwrap(),
-    //         ],
-    //     )
-    // }
-
-    // #[tokio::test]
-    // async fn window_function() -> Result<()> {
-    //     let input: Arc<dyn ExecutionPlan> = unimplemented!();
-    //     let input_schema = input.schema();
-    //     let window_expr = vec![];
-    //     WindowAggExec::try_new(window_expr, input, input_schema);
-    // }
+    use super::*;
+    use crate::physical_plan::aggregates::AggregateFunction;
+    use crate::physical_plan::collect;
+    use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
+    use crate::physical_plan::expressions::col;
+    use crate::test;
+    use arrow::array::*;
+
+    fn create_test_schema(partitions: usize) -> Result<(Arc<dyn ExecutionPlan>, SchemaRef)> {
+        let schema = test::aggr_test_schema();
+        let path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;
+        let csv = CsvExec::try_new(
+            &path,
+            CsvReadOptions::new().schema(&schema),
+            None,
+            1024,
+            None,
+        )?;
+
+        let input = Arc::new(csv);
+        Ok((input, schema))
+    }
+
+    #[tokio::test]
+    async fn window_function_input_partition() -> Result<()> {
+        let (input, schema) = create_test_schema(4)?;
+
+        let window_exec = Arc::new(WindowAggExec::try_new(
+            vec![create_window_expr(
+                &WindowFunction::AggregateFunction(AggregateFunction::Count),
+                &[col("c3")],
+                schema.as_ref(),
+                "count".to_owned(),
+            )?],
+            input,
+            schema.clone(),
+        )?);
+
+        let result = collect(window_exec).await;
+
+        assert!(result.is_err());
+        if let Some(DataFusionError::Internal(msg)) = result.err() {
+            assert_eq!(
+                msg,
+                "WindowAggExec requires a single input partition".to_owned()
+            );
+        } else {
+            unreachable!("Expect an internal error to happen");
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn window_function() -> Result<()> {
+        let (input, schema) = create_test_schema(1)?;
+
+        let window_exec = Arc::new(WindowAggExec::try_new(
+            vec![
+                create_window_expr(
+                    &WindowFunction::AggregateFunction(AggregateFunction::Count),
+                    &[col("c3")],
+                    schema.as_ref(),
+                    "count".to_owned(),
+                )?,
+                create_window_expr(
+                    &WindowFunction::AggregateFunction(AggregateFunction::Max),
+                    &[col("c3")],
+                    schema.as_ref(),
+                    "max".to_owned(),
+                )?,
+                create_window_expr(
+                    &WindowFunction::AggregateFunction(AggregateFunction::Min),
+                    &[col("c3")],
+                    schema.as_ref(),
+                    "min".to_owned(),
+                )?,
+            ],
+            input,
+            schema.clone(),
+        )?);
+
+        let result: Vec<RecordBatch> = collect(window_exec).await?;
+        assert_eq!(result.len(), 1);
+
+        let columns = result[0].columns();
+
+        // c3 is small int
+
+        let count: &UInt64Array = as_primitive_array(&columns[0]);
+        assert_eq!(count.value(0), 100);
+        assert_eq!(count.value(99), 100);
+
+        let max: &Int8Array = as_primitive_array(&columns[1]);
+        assert_eq!(max.value(0), 125);
+        assert_eq!(max.value(99), 125);
+
+        let min: &Int8Array = as_primitive_array(&columns[2]);
+        assert_eq!(min.value(0), -117);
+        assert_eq!(min.value(99), -117);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 55bc88eedf9ab..f5b416f789736 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -807,17 +807,20 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
         avg(c3) over (), \
         count(c3) over (), \
         max(c3) over (), \
-        min(c3) over () \
+        min(c3) over (), \
+        first_value(c3) over (), \
+        last_value(c3) over (), \
+        nth_value(c3, 2) over ()
         from aggregate_test_100 \
-        order by c2 \
+        order by c2
         limit 5";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
-        vec!["1", "781", "7.81", "100", "125", "-117"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+        vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
     ];
    assert_eq!(expected, actual);
    Ok(())
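
Editor's note: the three expressions added above deliberately share one accumulator, with n == 0 reserved as the LAST_VALUE sentinel (SQL positions are 1-based, so 0 is never a legal NTH_VALUE argument). Below is a minimal standalone sketch of that behaviour; it is not code from this patch. It uses i64 and Option in place of ScalarValue, and plain methods in place of the WindowAccumulator trait, purely to illustrate how FIRST_VALUE, LAST_VALUE and NTH_VALUE(expr, n) all fall out of the same scan/evaluate loop, including the NULL result when n exceeds the number of input rows.

// Standalone illustration only; the names and types here are not part of the patch.
struct NthValueSketch {
    n: u32,             // 0 = "last value" sentinel, otherwise the 1-based target position
    offset: u32,        // how many values have been scanned so far
    value: Option<i64>, // stand-in for ScalarValue; None plays the role of a typed NULL
}

impl NthValueSketch {
    fn new(n: u32) -> Self {
        Self { n, offset: 0, value: None }
    }

    // Mirrors the role of WindowAccumulator::scan: called once per input value, in order.
    fn scan(&mut self, v: i64) {
        if self.n == 0 {
            // last_value: keep overwriting with the newest value seen
            self.value = Some(v);
        } else if self.offset < self.n {
            self.offset += 1;
            if self.offset == self.n {
                self.value = Some(v);
            }
        }
    }

    // Mirrors the role of WindowAccumulator::evaluate: the final result for the window.
    fn evaluate(&self) -> Option<i64> {
        self.value
    }
}

fn main() {
    let rows: [i64; 4] = [10, 20, 30, 40];

    let mut first = NthValueSketch::new(1);  // FIRST_VALUE
    let mut last = NthValueSketch::new(0);   // LAST_VALUE via the 0 sentinel
    let mut second = NthValueSketch::new(2); // NTH_VALUE(expr, 2)
    let mut tenth = NthValueSketch::new(10); // n larger than the input stays NULL

    for v in rows {
        first.scan(v);
        last.scan(v);
        second.scan(v);
        tenth.scan(v);
    }

    assert_eq!(first.evaluate(), Some(10));
    assert_eq!(last.evaluate(), Some(40));
    assert_eq!(second.evaluate(), Some(20));
    assert_eq!(tenth.evaluate(), None);
}

Because the accumulator only ever sees values in scan order, LAST_VALUE is simply "overwrite on every row", which is why a single NthValueAccumulator can back all three built-in functions without any extra state.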