From c3bf2428acabb0ed2a0e5898c6fff7f4ad24b55c Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Sat, 30 Oct 2021 11:07:20 +0200 Subject: [PATCH 1/2] Improve GetIndexedFieldExpr adding utf8 key based access for struct values --- datafusion/src/field_util.rs | 21 ++- datafusion/src/logical_plan/expr.rs | 2 +- .../expressions/get_indexed_field.rs | 165 ++++++++++++++++-- datafusion/tests/sql.rs | 44 +++++ 4 files changed, 218 insertions(+), 14 deletions(-) diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index 9d5facebc0c1f..272c17b60887b 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field}; use crate::error::{DataFusionError, Result}; use crate::scalar::ScalarValue; -/// Returns the field access indexed by `key` from a [`DataType::List`] +/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`] /// # Error /// Errors if /// * the `data_type` is not a Struct or, @@ -39,6 +39,25 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { + if s.is_empty() { + Err(DataFusionError::Plan( + "Struct based indexed access requires a non empty string".to_string(), + )) + } else { + let field = fields.iter().find(|f| f.name() == s); + match field { + None => Err(DataFusionError::Plan(format!( + "Field {} not found in struct", + s + ))), + Some(f) => Ok(f.clone()), + } + } + } + (DataType::Struct(_), _) => Err(DataFusionError::Plan( + "Only utf8 strings are valid as an indexed field in a struct".to_string(), + )), (DataType::List(_), _) => Err(DataFusionError::Plan( "Only ints are valid as an indexed field in a list".to_string(), )), diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 499a8c720dba9..9b84dfc8df727 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -246,7 +246,7 @@ pub enum Expr { IsNull(Box), /// arithmetic 
negation of an expression, the operand must be of a signed numeric data type Negative(Box), - /// Returns the field of a [`ListArray`] by key + /// Returns the field of a [`ListArray`] or [`StructArray`] by key GetIndexedField { /// the expression to take the field from expr: Box, diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index 8a9191e9c346d..0f004743ec498 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -34,7 +34,7 @@ use crate::{ field_util::get_indexed_field as get_data_type_field, physical_plan::{ColumnarValue, PhysicalExpr}, }; -use arrow::array::ListArray; +use arrow::array::{ListArray, StructArray}; use std::fmt::Debug; /// expression to get a field of a struct array. @@ -81,7 +81,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => match (array.data_type(), &self.key) { - (DataType::List(_), _) if self.key.is_null() => { + (DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => { let scalar_null: ScalarValue = array.data_type().try_into()?; Ok(ColumnarValue::Scalar(scalar_null)) } @@ -100,6 +100,13 @@ impl PhysicalExpr for GetIndexedFieldExpr { let iter = concat(vec.as_slice()).unwrap(); Ok(ColumnarValue::Array(iter)) } + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { + let as_struct_array = array.as_any().downcast_ref::().unwrap(); + match as_struct_array.column_by_name(k) { + None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))), + Some(col) => Ok(ColumnarValue::Array(col.clone())) + } + } (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. 
Tried {} with {} index", dt, key))), }, ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( @@ -112,18 +119,16 @@ impl PhysicalExpr for GetIndexedFieldExpr { #[cfg(test)] mod tests { use super::*; + use crate::arrow::array::GenericListArray; use crate::error::Result; use crate::physical_plan::expressions::{col, lit}; - use arrow::array::{ListBuilder, StringBuilder}; + use arrow::array::{ + Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, StructBuilder, + }; use arrow::{array::StringArray, datatypes::Field}; - fn get_indexed_field_test( - list_of_lists: Vec>>, - index: i64, - expected: Vec>, - ) -> Result<()> { - let schema = list_schema("l"); - let builder = StringBuilder::new(3); + fn build_utf8_lists(list_of_lists: Vec>>) -> GenericListArray { + let builder = StringBuilder::new(list_of_lists.len()); let mut lb = ListBuilder::new(builder); for values in list_of_lists { let builder = lb.values(); @@ -137,9 +142,18 @@ mod tests { lb.append(true).unwrap(); } - let expr = col("l", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; + lb.finish() + } + fn get_indexed_field_test( + list_of_lists: Vec>>, + index: i64, + expected: Vec>, + ) -> Result<()> { + let schema = list_schema("l"); + let list_col = build_utf8_lists(list_of_lists); + let expr = col("l", &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_col)])?; let key = ScalarValue::Int64(Some(index)); let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); @@ -222,4 +236,131 @@ mod tests { let expr = col("l", &schema).unwrap(); get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. 
Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index") } + + fn build_struct( + fields: Vec, + list_of_tuples: Vec<(Option, Vec>)>, + ) -> StructArray { + let foo_builder = Int64Array::builder(list_of_tuples.len()); + let str_builder = StringBuilder::new(list_of_tuples.len()); + let bar_builder = ListBuilder::new(str_builder); + let mut builder = StructBuilder::new( + fields, + vec![Box::new(foo_builder), Box::new(bar_builder)], + ); + for (int_value, list_value) in list_of_tuples { + let fb = builder.field_builder::(0).unwrap(); + match int_value { + None => fb.append_null(), + Some(v) => fb.append_value(v), + } + .unwrap(); + builder.append(true).unwrap(); + let lb = builder + .field_builder::>(1) + .unwrap(); + for str_value in list_value { + match str_value { + None => lb.values().append_null(), + Some(v) => lb.values().append_value(v), + } + .unwrap(); + } + lb.append(true).unwrap(); + } + builder.finish() + } + + fn get_indexed_field_mixed_test( + list_of_tuples: Vec<(Option, Vec>)>, + expected_strings: Vec>>, + expected_ints: Vec>, + ) -> Result<()> { + let struct_col = "s"; + let fields = vec![ + Field::new("foo", DataType::Int64, true), + Field::new( + "bar", + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), + true, + ), + ]; + let schema = Schema::new(vec![Field::new( + struct_col, + DataType::Struct(fields.clone()), + true, + )]); + let struct_col = build_struct(fields, list_of_tuples.clone()); + + let struct_col_expr = col("s", &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_col)])?; + + let int_field_key = ScalarValue::Utf8(Some("foo".to_string())); + let get_field_expr = Arc::new(GetIndexedFieldExpr::new( + struct_col_expr.clone(), + int_field_key, + )); + let result = get_field_expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to Int64Array"); + let expected = &Int64Array::from(expected_ints); + assert_eq!(expected, result); + + let list_field_key = ScalarValue::Utf8(Some("bar".to_string())); + let get_list_expr = + Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key)); + let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect(&format!("failed to downcast to ListArray : {:?}", result)); + let expected = + &build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect()); + assert_eq!(expected, result); + + for (i, expected) in expected_strings.into_iter().enumerate() { + let get_nested_str_expr = Arc::new(GetIndexedFieldExpr::new( + get_list_expr.clone(), + ScalarValue::Int64(Some(i as i64)), + )); + let result = get_nested_str_expr + .evaluate(&batch)? + .into_array(batch.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect(&format!("failed to downcast to StringArray : {:?}", result)); + let expected = &StringArray::from(expected); + assert_eq!(expected, result); + } + Ok(()) + } + + #[test] + fn get_indexed_field_struct() -> Result<()> { + let list_of_structs = vec![ + (Some(10), vec![Some("a"), Some("b"), None]), + (Some(15), vec![None, Some("c"), Some("d")]), + (None, vec![Some("e"), None, Some("f")]), + ]; + + let expected_list = vec![ + vec![Some("a"), None, Some("e")], + vec![Some("b"), Some("c"), None], + vec![None, Some("d"), Some("f")], + ]; + + let expected_ints = vec![Some(10), Some(15), None]; + + get_indexed_field_mixed_test( + list_of_structs.clone(), + expected_list, + expected_ints.clone(), + )?; + Ok(()) + } } diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index f1e988814addc..bf5608fb241e2 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5385,3 +5385,47 @@ async fn 
query_nested_get_indexed_field() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn query_nested_get_indexed_field_on_struct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); + // Nested schema of { "some_struct": { "bar": [i64] } } + let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)]; + let schema = Arc::new(Schema::new(vec![Field::new( + "some_struct", + DataType::Struct(struct_fields.clone()), + false, + )])); + + let builder = PrimitiveBuilder::::new(3); + let nested_lb = ListBuilder::new(builder); + let mut sb = StructBuilder::new(struct_fields, vec![Box::new(nested_lb)]); + for int_vec in vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]] { + let lb = sb.field_builder::>(0).unwrap(); + for int in int_vec { + lb.values().append_value(int).unwrap(); + } + lb.append(true).unwrap(); + } + let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(sb.finish())])?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + let table_a = Arc::new(table); + + ctx.register_table("structs", table_a)?; + + // Access the struct field "bar" by name, then index into the resulting list + let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["[0, 1, 2, 3]"], + vec!["[4, 5, 6, 7]"], + vec!["[8, 9, 10, 11]"], + ]; + assert_eq!(expected, actual); + let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"], vec!["4"], vec!["8"]]; + assert_eq!(expected, actual); + Ok(()) +} From 3145cf5da6dd1eac292a1aed3af9054fef327337 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Sat, 30 Oct 2021 14:35:01 +0200 Subject: [PATCH 2/2] fix clippies --- .../src/physical_plan/expressions/get_indexed_field.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3
deletions(-) diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index 0f004743ec498..7e60698aa3112 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -317,7 +317,7 @@ mod tests { let result = result .as_any() .downcast_ref::() - .expect(&format!("failed to downcast to ListArray : {:?}", result)); + .unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", result)); let expected = &build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect()); assert_eq!(expected, result); @@ -333,7 +333,9 @@ mod tests { let result = result .as_any() .downcast_ref::() - .expect(&format!("failed to downcast to StringArray : {:?}", result)); + .unwrap_or_else(|| { + panic!("failed to downcast to StringArray : {:?}", result) + }); let expected = &StringArray::from(expected); assert_eq!(expected, result); } @@ -359,7 +361,7 @@ mod tests { get_indexed_field_mixed_test( list_of_structs.clone(), expected_list, - expected_ints.clone(), + expected_ints, )?; Ok(()) }