From b6c6a3c896e4986db35d8ef98a79c44d480b4c5b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 27 Jul 2021 13:20:12 -0400 Subject: [PATCH 1/2] Add support for group by hash of a null column, tests for same --- .../src/physical_plan/hash_aggregate.rs | 60 +++++++++- datafusion/src/scalar.rs | 37 +++++- datafusion/tests/sql.rs | 110 ++++++++++++++++++ 3 files changed, 202 insertions(+), 5 deletions(-) diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index eb4a356e88ce8..7094f81874d5b 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -395,7 +395,10 @@ fn group_aggregate_batch( // We can safely unwrap here as we checked we can create an accumulator before let accumulator_set = create_accumulators(aggr_expr).unwrap(); batch_keys.push(key.clone()); - let _ = create_group_by_values(&group_values, row, &mut group_by_values); + // Note it would be nice to make this a real error (rather than panic) + // but it is better than silently ignoring the issue and getting wrong results + create_group_by_values(&group_values, row, &mut group_by_values) + .expect("can not create group by value"); ( key.clone(), (group_by_values.clone(), accumulator_set, vec![row as u32]), @@ -508,7 +511,9 @@ fn dictionary_create_key_for_col( } /// Appends a sequence of [u8] bytes for the value in `col[row]` to -/// `vec` to be used as a key into the hash map +/// `vec` to be used as a key into the hash map. +/// +/// NOTE: This functon does not check col.is_valid(). 
Caller must do so fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec) -> Result<()> { match col.data_type() { DataType::Boolean => { @@ -640,6 +645,50 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec) -> Result<( } /// Create a key `Vec` that is used as key for the hashmap +/// +/// This looks like +/// [null_byte][col_value_bytes][null_byte][col_value_bytes] +/// +/// Note that relatively uncommon patterns (e.g. not 0x00) are chosen +/// for the null_byte to make debugging easier. The actual values are +/// arbitrary. +/// +/// For a NULL value in a column, the key looks like +/// [0xFE] +/// +/// For a Non-NULL value in a column, this looks like: +/// [0xFF][byte representation of column value] +/// +/// Example of a key with no NULL values: +/// ```text +/// 0xFF byte at the start of each column +/// signifies the value is non-null +/// │ +/// +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ┐ +/// +/// │ string len │ 0x1234 +/// { ▼ (as usize le) "foo" ▼(as u16 le) +/// k1: "foo" ╔ ═┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──╦ ═┌──┬──┐ +/// k2: 0x1234u16 FF║03│00│00│00│00│00│00│00│"f│"o│"o│FF║34│12│ +/// } ╚ ═└──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──╩ ═└──┴──┘ +/// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 +/// ``` +/// +/// Example of a key with NULL values: +/// +///```text +/// 0xFE byte at the start of k1 column +/// ┌ ─ signifies the value is NULL +/// +/// └ ┐ +/// 0x1234 +/// { ▼ (as u16 le) +/// k1: NULL ╔ ═╔ ═┌──┬──┐ +/// k2: 0x1234u16 FE║FF║34│12│ +/// } ╚ ═╚ ═└──┴──┘ +/// 0 1 2 3 +///``` pub(crate) fn create_key( group_by_keys: &[ArrayRef], row: usize, @@ -647,7 +696,12 @@ pub(crate) fn create_key( ) -> Result<()> { vec.clear(); for col in group_by_keys { - create_key_for_col(col, row, vec)? + if !col.is_valid(row) { + vec.push(0xFE); + } else { + vec.push(0xFF); + create_key_for_col(col, row, vec)? 
+ } } Ok(()) } diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 8efea63e82368..90c9bf7369d4a 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -28,7 +28,7 @@ use arrow::{ }, }; use ordered_float::OrderedFloat; -use std::convert::Infallible; +use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -796,6 +796,11 @@ impl ScalarValue { /// Converts a value in `array` at `index` into a ScalarValue pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + // handle NULL value + if !array.is_valid(index) { + return array.data_type().try_into(); + } + Ok(match array.data_type() { DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), @@ -897,6 +902,7 @@ impl ScalarValue { let dict_array = array.as_any().downcast_ref::>().unwrap(); // look up the index in the values dictionary + // (note validity was previously checked in `try_from_array`) let keys_col = dict_array.keys(); let values_index = keys_col.value(index).to_usize().ok_or_else(|| { DataFusionError::Internal(format!( @@ -1132,6 +1138,7 @@ impl_try_from!(Boolean, bool); impl TryFrom<&DataType> for ScalarValue { type Error = DataFusionError; + /// Create a Null instance of ScalarValue for this datatype fn try_from(datatype: &DataType) -> Result { Ok(match datatype { DataType::Boolean => ScalarValue::Boolean(None), @@ -1161,12 +1168,15 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Timestamp(TimeUnit::Nanosecond, _) => { ScalarValue::TimestampNanosecond(None) } + DataType::Dictionary(_index_type, value_type) => { + value_type.as_ref().try_into()? 
+ } DataType::List(ref nested_type) => { ScalarValue::List(None, Box::new(nested_type.data_type().clone())) } _ => { return Err(DataFusionError::NotImplemented(format!( - "Can't create a scalar of type \"{:?}\"", + "Can't create a scalar from data_type \"{:?}\"", datatype ))) } @@ -1535,6 +1545,29 @@ mod tests { "{}", result); } + #[test] + fn scalar_try_from_array_null() { + let array = vec![Some(33), None].into_iter().collect::(); + let array: ArrayRef = Arc::new(array); + + assert_eq!( + ScalarValue::Int64(Some(33)), + ScalarValue::try_from_array(&array, 0).unwrap() + ); + assert_eq!( + ScalarValue::Int64(None), + ScalarValue::try_from_array(&array, 1).unwrap() + ); + } + + #[test] + fn scalar_try_from_dict_datatype() { + let data_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + let data_type = &data_type; + assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap()) + } + #[test] fn size_of_scalar() { // Since ScalarValues are used in a non trivial number of places, diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index bfe2f2fc49138..f3eed9c7bf67d 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -3014,6 +3014,109 @@ async fn query_count_distinct() -> Result<()> { Ok(()) } +#[tokio::test] +async fn query_group_on_null() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![ + Some(0), + Some(3), + None, + Some(1), + Some(3), + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1"; + + let actual = execute_to_batches(&mut ctx, sql).await; + + // Note that the results also + // include a row for NULL (c1=NULL, count = 1) + let expected = vec![ + "+-----------------+----+", + "| 
COUNT(UInt8(1)) | c1 |", + "+-----------------+----+", + "| 1 | |", + "| 1 | 0 |", + "| 1 | 1 |", + "| 2 | 3 |", + "+-----------------+----+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_group_on_null_multi_col() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![ + Some(0), + Some(0), + Some(3), + None, + None, + Some(3), + Some(0), + None, + Some(3), + ])), + Arc::new(StringArray::from(vec![ + None, + None, + Some("foo"), + None, + Some("bar"), + Some("foo"), + None, + Some("bar"), + Some("foo"), + ])), + ], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2"; + + let actual = execute_to_batches(&mut ctx, sql).await; + + // Note that the results also include rows for NULL group values, + // including a row where both are NULL (c1=NULL, c2=NULL, count = 1) + let expected = vec![ + "+-----------------+----+-----+", + "| COUNT(UInt8(1)) | c1 | c2 |", + "+-----------------+----+-----+", + "| 1 | | |", + "| 2 | | bar |", + "| 3 | 0 | |", + "| 3 | 3 | foo |", + "+-----------------+----+-----+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + // Also run query with group columns reversed (results should be the same) + let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn query_on_string_dictionary() -> Result<()> { // Test to ensure DataFusion can operate on dictionary types @@ -3067,6 +3170,13 @@ async fn query_on_string_dictionary() -> Result<()> { let expected = vec![vec!["2"]]; assert_eq!(expected, actual); + // grouping + let 
sql = "SELECT d1, COUNT(*) FROM test group by d1"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["NULL", "1"], vec!["one", "1"], vec!["three", "1"]]; + assert_eq!(expected, actual); + Ok(()) } From 3378cef7c58cf86201604ce63b6878260b1dbd81 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 30 Jul 2021 15:57:37 -0400 Subject: [PATCH 2/2] Update datafusion/src/physical_plan/hash_aggregate.rs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniël Heres --- datafusion/src/physical_plan/hash_aggregate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 7094f81874d5b..5c3c57695d0f0 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -513,7 +513,7 @@ fn dictionary_create_key_for_col( /// Appends a sequence of [u8] bytes for the value in `col[row]` to /// `vec` to be used as a key into the hash map. /// -/// NOTE: This functon does not check col.is_valid(). Caller must do so +/// NOTE: This function does not check col.is_valid(). Caller must do so fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec) -> Result<()> { match col.data_type() { DataType::Boolean => {