From 5b68cbbabf2c51d1e6831e9f40cd61c4bfbddc0c Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 28 Jun 2021 09:51:46 +0800 Subject: [PATCH] add integration tests for rank, dense_rank --- datafusion/src/execution/context.rs | 24 ++--- .../physical_plan/expressions/nth_value.rs | 94 ++++++++++++++----- datafusion/tests/sql.rs | 21 ++--- .../sqls/simple_window_built_in_functions.sql | 27 ++++++ .../sqls/simple_window_full_aggregation.sql | 2 +- .../simple_window_ordered_aggregation.sql | 2 +- .../simple_window_partition_aggregation.sql | 2 +- ...ple_window_partition_order_aggregation.sql | 2 +- ...imple_window_ranked_built_in_functions.sql | 22 +++++ integration-tests/test_psql_parity.py | 4 +- 10 files changed, 146 insertions(+), 54 deletions(-) create mode 100644 integration-tests/sqls/simple_window_built_in_functions.sql create mode 100644 integration-tests/sqls/simple_window_ranked_built_in_functions.sql diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 5c41ed26eea43..5df8e20ea6060 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1335,11 +1335,11 @@ mod tests { "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2) | LAST_VALUE(c2) | NTH_VALUE(c2,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |", "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", - "| 0 | 1 | 1 | 1 | 10 | 2 | 1 | 1 | 1 | 1 | 1 |", - "| 0 | 2 | 2 | 1 | 10 | 2 | 3 | 2 | 2 | 1 | 1.5 |", - "| 0 | 3 | 3 | 1 | 10 | 2 | 6 | 3 | 3 | 1 | 2 |", - "| 0 | 4 | 4 | 1 | 10 | 2 | 10 | 4 | 4 | 1 | 2.5 |", - "| 0 | 5 | 5 | 1 | 10 | 2 | 15 | 5 | 5 | 1 | 3 |", + "| 0 | 1 | 1 | 1 | 1 | | 1 | 1 | 1 | 1 | 1 |", + "| 0 | 2 | 2 | 1 | 2 | 2 | 3 | 2 | 2 | 1 | 1.5 |", + "| 0 | 3 | 3 | 1 | 3 | 2 | 6 | 3 | 3 | 1 | 2 |", + "| 0 | 4 | 4 | 1 | 4 | 2 | 10 | 4 | 4 | 1 | 2.5 |", + "| 0 | 5 | 5 | 1 | 5 | 2 | 15 | 5 | 5 | 1 | 3 |", "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", ]; @@ -1392,7 +1392,7 @@ mod tests { ROW_NUMBER() OVER (PARTITION BY c2 ORDER BY c1), \ FIRST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \ LAST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \ - NTH_VALUE(c2 + c1, 2) OVER (PARTITION BY c2 ORDER BY c1), \ + NTH_VALUE(c2 + c1, 1) OVER (PARTITION BY c2 ORDER BY c1), \ SUM(c2) OVER (PARTITION BY c2 ORDER BY c1), \ COUNT(c2) OVER (PARTITION BY c2 ORDER BY c1), \ MAX(c2) OVER (PARTITION BY c2 ORDER BY c1), \ @@ -1407,13 +1407,13 @@ mod tests { let expected = vec![ "+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+", - "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |", + "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(1)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |", "+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+", - "| 0 | 1 | 1 | 1 | 4 | 2 | 1 | 1 | 1 | 1 | 1 |", - "| 0 | 2 | 1 | 2 | 5 | 3 | 2 | 1 | 2 | 2 | 2 
|",
-            "| 0 | 3 | 1 | 3 | 6 | 4 | 3 | 1 | 3 | 3 | 3 |",
-            "| 0 | 4 | 1 | 4 | 7 | 5 | 4 | 1 | 4 | 4 | 4 |",
-            "| 0 | 5 | 1 | 5 | 8 | 6 | 5 | 1 | 5 | 5 | 5 |",
+            "| 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |",
+            "| 0 | 2 | 1 | 2 | 2 | 2 | 2 | 1 | 2 | 2 | 2 |",
+            "| 0 | 3 | 1 | 3 | 3 | 3 | 3 | 1 | 3 | 3 | 3 |",
+            "| 0 | 4 | 1 | 4 | 4 | 4 | 4 | 1 | 4 | 4 | 4 |",
+            "| 0 | 5 | 1 | 5 | 5 | 5 | 5 | 1 | 5 | 5 | 5 |",
             "+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
         ];
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
index 3897ae5cb53e0..854078e232f00 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -22,9 +22,11 @@ use crate::physical_plan::window_functions::PartitionEvaluator;
 use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
 use crate::scalar::ScalarValue;
 use arrow::array::{new_null_array, ArrayRef};
+use arrow::compute::kernels::window::shift;
 use arrow::datatypes::{DataType, Field};
 use arrow::record_batch::RecordBatch;
 use std::any::Any;
+use std::iter;
 use std::ops::Range;
 use std::sync::Arc;

@@ -138,21 +140,56 @@ pub(crate) struct NthValueEvaluator {
 }

 impl PartitionEvaluator for NthValueEvaluator {
-    fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
-        let value = &self.values[0];
+    fn include_rank(&self) -> bool {
+        true
+    }
+
+    fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef> {
+        unreachable!("first, last, and nth_value evaluation must be called with evaluate_partition_with_rank")
+    }
+
+    fn evaluate_partition_with_rank(
+        &self,
+        partition: Range<usize>,
+        ranks_in_partition: &[Range<usize>],
+    ) -> Result<ArrayRef> {
+        let arr = &self.values[0];
         let num_rows = partition.end - partition.start;
-        let value = value.slice(partition.start, num_rows);
-        let index: usize = match self.kind {
-            NthValueKind::First => 0,
-            NthValueKind::Last => (num_rows as usize) - 1,
-            NthValueKind::Nth(n) => (n as usize) - 1,
-        };
-        Ok(if index >= num_rows {
-            new_null_array(value.data_type(), num_rows)
-        } else {
-            let value = ScalarValue::try_from_array(&value, index)?;
-            value.to_array_of_size(num_rows)
-        })
+        match self.kind {
+            NthValueKind::First => {
+                let value = ScalarValue::try_from_array(arr, partition.start)?;
+                Ok(value.to_array_of_size(num_rows))
+            }
+            NthValueKind::Last => {
+                // because the default window frame is between unbounded preceding and the
+                // current row with peer evaluation, the last value extends to the end of its peer group
+                let values = ranks_in_partition
+                    .iter()
+                    .map(|range| {
+                        let len = range.end - range.start;
+                        let value = ScalarValue::try_from_array(arr, range.end - 1)?;
+                        Ok(iter::repeat(value).take(len))
+                    })
+                    .collect::<Result<Vec<_>>>()?
+                    .into_iter()
+                    .flatten();
+                ScalarValue::iter_to_array(values)
+            }
+            NthValueKind::Nth(n) => {
+                let index = (n as usize) - 1;
+                if index >= num_rows {
+                    Ok(new_null_array(arr.data_type(), num_rows))
+                } else {
+                    let value =
+                        ScalarValue::try_from_array(arr, partition.start + index)?;
+                    let arr = value.to_array_of_size(num_rows);
+                    // because the default window frame is between unbounded preceding and the
+                    // current row, values at indices smaller than `index` must be null, hence the
+                    // shift; this will change once window frames other than the default are implemented
+                    shift(arr.as_ref(), index as i64).map_err(DataFusionError::ArrowError)
+                }
+            }
+        }
     }
 }

@@ -164,16 +201,17 @@ mod tests {
     use arrow::record_batch::RecordBatch;
     use arrow::{array::*, datatypes::*};

-    fn test_i32_result(expr: NthValue, expected: Vec<i32>) -> Result<()> {
+    fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> {
         let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
         let values = vec![arr];
         let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
         let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
-        let result = expr.create_evaluator(&batch)?.evaluate(vec![0..8])?;
+        let result = expr
+            .create_evaluator(&batch)?
+            .evaluate_with_rank(vec![0..8], vec![0..8])?;
         assert_eq!(1, result.len());
         let result = result[0].as_any().downcast_ref::<Int32Array>().unwrap();
-        let result = result.values();
-        assert_eq!(expected, result);
+        assert_eq!(expected, *result);
         Ok(())
     }

@@ -184,7 +222,7 @@
             Arc::new(Column::new("arr", 0)),
             DataType::Int32,
         );
-        test_i32_result(first_value, vec![1; 8])?;
+        test_i32_result(first_value, Int32Array::from_iter_values(vec![1; 8]))?;
         Ok(())
     }

@@ -195,7 +233,7 @@
             Arc::new(Column::new("arr", 0)),
             DataType::Int32,
         );
-        test_i32_result(last_value, vec![8; 8])?;
+        test_i32_result(last_value, Int32Array::from_iter_values(vec![8; 8]))?;
         Ok(())
     }

@@ -207,7 +245,7 @@
             DataType::Int32,
             1,
         )?;
-        test_i32_result(nth_value, vec![1; 8])?;
+        test_i32_result(nth_value, Int32Array::from_iter_values(vec![1; 8]))?;
         Ok(())
     }

@@ -219,7 +257,19 @@
             DataType::Int32,
             2,
         )?;
-        test_i32_result(nth_value, vec![-2; 8])?;
+        test_i32_result(
+            nth_value,
+            Int32Array::from(vec![
+                None,
+                Some(-2),
+                Some(-2),
+                Some(-2),
+                Some(-2),
+                Some(-2),
+                Some(-2),
+                Some(-2),
+            ]),
+        )?;
         Ok(())
     }
 }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index c06a4bb1462ee..5cb5529ba80e7 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -903,7 +903,7 @@ async fn csv_query_window_with_partition_by() -> Result<()> {
             "-21481",
             "-16974",
             "-21481",
-            "-21481",
+            "NULL",
         ],
         vec![
             "141680161",
@@ -952,15 +952,8 @@ async fn csv_query_window_with_order_by() -> Result<()> {
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![
         vec![
-            "28774375",
-            "61035129",
-            "61035129",
-            "1",
-            "61035129",
-            "61035129",
-            "61035129",
-            "2025611582",
-            "-108973366",
+            "28774375", "61035129", "61035129", "1", "61035129", "61035129", "61035129",
+            "61035129", "NULL",
         ],
         vec![
             "63044568",
@@ -970,7 +963,7 @@
             "61035129",
             "-108973366",
             "61035129",
-            "2025611582",
+            "-108973366",
             "-108973366",
         ],
         vec![
@@ -981,7 +974,7 @@
             "623103518",
             "-108973366",
             "61035129",
-            "2025611582",
+            "623103518",
             "-108973366",
         ],
         vec![
@@ -992,7 +985,7 @@
             "623103518",
             "-1927628110",
             "61035129",
-            "2025611582",
+            "-1927628110",
             "-108973366",
         ],
         vec![
@@ -1003,7 +996,7 @@
             "623103518",
             "-1927628110",
             "61035129",
-            "2025611582",
+            "-1899175111",
             "-108973366",
         ],
     ];
diff --git a/integration-tests/sqls/simple_window_built_in_functions.sql b/integration-tests/sqls/simple_window_built_in_functions.sql
new file mode 100644
index 0000000000000..e76b383060026
--- /dev/null
+++ 
b/integration-tests/sqls/simple_window_built_in_functions.sql @@ -0,0 +1,27 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +SELECT + c9, + row_number() OVER (ORDER BY c9) row_num, + first_value(c9) OVER (ORDER BY c9) first_c9, + first_value(c9) OVER (ORDER BY c9 DESC) first_c9_desc, + last_value(c9) OVER (ORDER BY c9) last_c9, + last_value(c9) OVER (ORDER BY c9 DESC) last_c9_desc, + nth_value(c9, 2) OVER (ORDER BY c9) second_c9, + nth_value(c9, 2) OVER (ORDER BY c9 DESC) second_c9_desc +FROM test +ORDER BY c9; diff --git a/integration-tests/sqls/simple_window_full_aggregation.sql b/integration-tests/sqls/simple_window_full_aggregation.sql index 94860bc3b1835..7346f67fa4ba4 100644 --- a/integration-tests/sqls/simple_window_full_aggregation.sql +++ b/integration-tests/sqls/simple_window_full_aggregation.sql @@ -11,7 +11,7 @@ -- Unless required by applicable law or agreed to in writing, software -- distributed under the License is distributed on an "AS IS" BASIS, -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language gOVERning permissions and +-- See the License for the specific language governing permissions and -- limitations under the License. SELECT diff --git a/integration-tests/sqls/simple_window_ordered_aggregation.sql b/integration-tests/sqls/simple_window_ordered_aggregation.sql index d9f467b0cb09a..567c1881a3db6 100644 --- a/integration-tests/sqls/simple_window_ordered_aggregation.sql +++ b/integration-tests/sqls/simple_window_ordered_aggregation.sql @@ -11,7 +11,7 @@ -- Unless required by applicable law or agreed to in writing, software -- distributed under the License is distributed on an "AS IS" BASIS, -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language gOVERning permissions and +-- See the License for the specific language governing permissions and -- limitations under the License. SELECT diff --git a/integration-tests/sqls/simple_window_partition_aggregation.sql b/integration-tests/sqls/simple_window_partition_aggregation.sql index f395671db8cc8..bac4e465f626b 100644 --- a/integration-tests/sqls/simple_window_partition_aggregation.sql +++ b/integration-tests/sqls/simple_window_partition_aggregation.sql @@ -11,7 +11,7 @@ -- Unless required by applicable law or agreed to in writing, software -- distributed under the License is distributed on an "AS IS" BASIS, -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language gOVERning permissions and +-- See the License for the specific language governing permissions and -- limitations under the License. 
SELECT diff --git a/integration-tests/sqls/simple_window_partition_order_aggregation.sql b/integration-tests/sqls/simple_window_partition_order_aggregation.sql index a11a9ec6e4b1e..2702c0e2e0326 100644 --- a/integration-tests/sqls/simple_window_partition_order_aggregation.sql +++ b/integration-tests/sqls/simple_window_partition_order_aggregation.sql @@ -11,7 +11,7 @@ -- Unless required by applicable law or agreed to in writing, software -- distributed under the License is distributed on an "AS IS" BASIS, -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language gOVERning permissions and +-- See the License for the specific language governing permissions and -- limitations under the License. SELECT diff --git a/integration-tests/sqls/simple_window_ranked_built_in_functions.sql b/integration-tests/sqls/simple_window_ranked_built_in_functions.sql new file mode 100644 index 0000000000000..0ea6b042555cc --- /dev/null +++ b/integration-tests/sqls/simple_window_ranked_built_in_functions.sql @@ -0,0 +1,22 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +select + c9, + rank() OVER (PARTITION BY c2 ORDER BY c3) rank_by_c3, + dense_rank() OVER (PARTITION BY c2 ORDER BY c3) dense_rank_by_c3 +FROM test +ORDER BY c9; diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py index 92670bed0c4dd..2bb8da9fd5c58 100644 --- a/integration-tests/test_psql_parity.py +++ b/integration-tests/test_psql_parity.py @@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase): def test_parity(self): root = Path(os.path.dirname(__file__)) / "sqls" files = set(root.glob("*.sql")) - self.assertEqual(len(files), 9, msg="tests are missed") + self.assertEqual(len(files), 11, msg="tests are missed") for fname in files: with self.subTest(fname=fname): datafusion_output = pd.read_csv( @@ -82,7 +82,7 @@ def test_parity(self): ) psql_output = pd.read_csv(io.BytesIO(generate_csv_from_psql(fname))) self.assertTrue( - np.allclose(datafusion_output, psql_output), + np.allclose(datafusion_output, psql_output, equal_nan=True), msg=f"datafusion output=\n{datafusion_output}, psql_output=\n{psql_output}", )
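
Note on the semantics the updated expectations rely on: with an ORDER BY and no explicit frame clause, the default window frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so LAST_VALUE follows the current row's peer group, NTH_VALUE(expr, 2) is NULL until the frame holds at least two rows, and RANK/DENSE_RANK differ only in whether ranks after a tie are skipped. A minimal sketch of those semantics, written against PostgreSQL (the reference that test_psql_parity.py compares with); the inline VALUES table t(v) is hypothetical and not part of the test data:

SELECT
  v,
  last_value(v)   OVER (ORDER BY v) AS last_v,       -- equals v: the frame ends at the current row's last peer
  nth_value(v, 2) OVER (ORDER BY v) AS second_v,     -- NULL on the first row, then the second value overall
  rank()          OVER (ORDER BY v) AS rank_v,       -- ties share a rank; the next rank is skipped
  dense_rank()    OVER (ORDER BY v) AS dense_rank_v  -- ties share a rank; no gaps
FROM (VALUES (1), (2), (2), (3)) AS t(v)
ORDER BY v;

For v = 1, 2, 2, 3 this returns last_v = 1, 2, 2, 3; second_v = NULL, 2, 2, 2; rank_v = 1, 2, 2, 4; dense_rank_v = 1, 2, 2, 3, which mirrors the NULL-leading NTH_VALUE and peer-following LAST_VALUE expectations asserted in context.rs and sql.rs above.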