diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 4fef60020f779..b0769df1e9dbb 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -101,7 +101,9 @@ pub enum ScalarValue { FixedSizeBinary(i32, Option>), /// large binary LargeBinary(Option>), - /// list of nested ScalarValue + /// Fixed size list of nested ScalarValue + Fixedsizelist(Option>, FieldRef, i32), + /// List of nested ScalarValue List(Option>, FieldRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), @@ -196,6 +198,10 @@ impl PartialEq for ScalarValue { (FixedSizeBinary(_, _), _) => false, (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), (LargeBinary(_), _) => false, + (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { + v1.eq(v2) && t1.eq(t2) && l1.eq(l2) + } + (Fixedsizelist(_, _, _), _) => false, (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), (List(_, _), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), @@ -315,6 +321,14 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, + (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { + if t1.eq(t2) && l1.eq(l2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Fixedsizelist(_, _, _), _) => None, (List(v1, t1), List(v2, t2)) => { if t1.eq(t2) { v1.partial_cmp(v2) @@ -1518,6 +1532,11 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), + Fixedsizelist(v, t, l) => { + v.hash(state); + t.hash(state); + l.hash(state); + } List(v, t) => { v.hash(state); t.hash(state); @@ -1994,6 +2013,10 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, + ScalarValue::Fixedsizelist(_, field, length) => 
DataType::FixedSizeList( + Arc::new(Field::new("item", field.data_type().clone(), true)), + *length, + ), ScalarValue::List(_, field) => DataType::List(Arc::new(Field::new( "item", field.data_type().clone(), @@ -2142,6 +2165,7 @@ impl ScalarValue { ScalarValue::Binary(v) => v.is_none(), ScalarValue::FixedSizeBinary(_, v) => v.is_none(), ScalarValue::LargeBinary(v) => v.is_none(), + ScalarValue::Fixedsizelist(v, ..) => v.is_none(), ScalarValue::List(v, _) => v.is_none(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), @@ -2847,6 +2871,9 @@ impl ScalarValue { .collect::(), ), }, + ScalarValue::Fixedsizelist(..) => { + unimplemented!("FixedSizeList is not supported yet") + } ScalarValue::List(values, field) => Arc::new(match field.data_type() { DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), DataType::Int8 => build_list!(Int8Builder, Int8, values, size), @@ -3294,6 +3321,7 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val) } + ScalarValue::Fixedsizelist(..) 
=> unimplemented!(), ScalarValue::List(_, _) => unimplemented!(), ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val) @@ -3414,7 +3442,8 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(vals, field) => { + ScalarValue::Fixedsizelist(vals, field, _) + | ScalarValue::List(vals, field) => { vals.as_ref() .map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals)) .unwrap_or_default() @@ -3732,29 +3761,9 @@ impl fmt::Display for ScalarValue { ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, ScalarValue::Utf8(e) => format_option!(f, e)?, ScalarValue::LargeUtf8(e) => format_option!(f, e)?, - ScalarValue::Binary(e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::FixedSizeBinary(_, e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::LargeBinary(e) => match e { + ScalarValue::Binary(e) + | ScalarValue::FixedSizeBinary(_, e) + | ScalarValue::LargeBinary(e) => match e { Some(l) => write!( f, "{}", @@ -3765,7 +3774,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(e, _) => match e { + ScalarValue::Fixedsizelist(e, ..) | ScalarValue::List(e, _) => match e { Some(l) => write!( f, "{}", @@ -3849,6 +3858,7 @@ impl fmt::Debug for ScalarValue { } ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), + ScalarValue::Fixedsizelist(..) 
=> write!(f, "FixedSizeList([{self}])"), ScalarValue::List(_, _) => write!(f, "List([{self}])"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), diff --git a/datafusion/core/tests/data/fixed_size_list_array.parquet b/datafusion/core/tests/data/fixed_size_list_array.parquet new file mode 100644 index 0000000000000..aafc5ce62f52a Binary files /dev/null and b/datafusion/core/tests/data/fixed_size_list_array.parquet differ diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 0d99e6cbb3a1d..1f43c5f8e154e 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -417,8 +417,6 @@ select make_array(x, y) from foo2; # array_contains - - # array_contains scalar function #1 query BBB rowsort select array_contains(make_array(1, 2, 3), make_array(1, 1, 2, 3)), array_contains([1, 2, 3], [1, 1, 2]), array_contains([1, 2, 3], [2, 1, 3, 1]); @@ -531,3 +529,40 @@ SELECT FROM t ---- true true + +statement ok +CREATE EXTERNAL TABLE fixed_size_list_array STORED AS PARQUET LOCATION 'tests/data/fixed_size_list_array.parquet'; + +query T +select arrow_typeof(f0) from fixed_size_list_array; +---- +FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2) +FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2) + +query ? +select * from fixed_size_list_array; +---- +[1, 2] +[3, 4] + +query ? +select f0 from fixed_size_list_array; +---- +[1, 2] +[3, 4] + +query ? +select arrow_cast(f0, 'List(Int64)') from fixed_size_list_array; +---- +[1, 2] +[3, 4] + +query ? +select make_array(arrow_cast(f0, 'List(Int64)')) from fixed_size_list_array +---- +[[1, 2], [3, 4]] + +query ? 
+select make_array(f0) from fixed_size_list_array ---- [[1, 2], [3, 4]] diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 5d1fef53520ba..7cf4a233f7e0b 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -330,8 +330,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature, )?; - let expr = Expr::ScalarUDF(ScalarUDF::new(fun, new_expr)); - Ok(expr) + Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { let new_args = coerce_arguments_for_signature( @@ -520,7 +519,7 @@ fn coerce_window_frame( fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result { let left_type = expr.get_type(schema)?; get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; - expr.clone().cast_to(&DataType::Boolean, schema) + cast_expr(expr, &DataType::Boolean, schema) } /// Returns `expressions` coerced to types compatible with @@ -559,6 +558,25 @@ fn coerce_arguments_for_fun( return Ok(vec![]); } + let mut expressions: Vec<Expr> = expressions.to_vec(); + + // Cast FixedSizeList arguments of MakeArray to List before the coercion below; a failed type lookup is returned as an error instead of panicking + if *fun == BuiltinScalarFunction::MakeArray { + expressions = expressions + .into_iter() + .map(|expr| { + let data_type = expr.get_type(schema)?; + if let DataType::FixedSizeList(field, _) = data_type { + let field = field.as_ref().clone(); + let to_type = DataType::List(Arc::new(field)); + expr.cast_to(&to_type, schema) + } else { + Ok(expr) + } + }) + .collect::<Result<Vec<_>>>()?; + } + if *fun == BuiltinScalarFunction::MakeArray { // Find the final data type for the function arguments let current_types = expressions @@ -579,8 +597,7 @@ fn coerce_arguments_for_fun( .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) .collect(); } - - Ok(expressions.to_vec()) + Ok(expressions) } /// Cast `expr` to the specified 
type, if possible @@ -598,7 +615,7 @@ fn cast_array_expr( if from_type.equals_datatype(&DataType::Null) { Ok(expr.clone()) } else { - expr.clone().cast_to(to_type, schema) + cast_expr(expr, to_type, schema) } } @@ -625,7 +642,7 @@ fn coerce_agg_exprs_for_signature( input_exprs .iter() .enumerate() - .map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema)) + .map(|(i, expr)| cast_expr(expr, &coerced_types[i], schema)) .collect::>>() } @@ -746,6 +763,7 @@ mod test { use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::Field; use datafusion_common::tree_node::TreeNode; use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; @@ -763,7 +781,7 @@ mod test { use datafusion_physical_expr::expressions::AvgAccumulator; use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter, }; use crate::test::assert_analyzed_plan_eq; @@ -1220,6 +1238,58 @@ mod test { Ok(()) } + #[test] + fn test_casting_for_fixed_size_list() -> Result<()> { + let val = lit(ScalarValue::Fixedsizelist( + Some(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ]), + Arc::new(Field::new("item", DataType::Int32, true)), + 3, + )); + let expr = Expr::ScalarFunction(ScalarFunction { + fun: BuiltinScalarFunction::MakeArray, + args: vec![val.clone()], + }); + let schema = Arc::new(DFSchema::new_with_metadata( + vec![DFField::new_unqualified( + "item", + DataType::FixedSizeList( + Arc::new(Field::new("a", DataType::Int32, true)), + 3, + ), + true, + )], + std::collections::HashMap::new(), + )?); + let mut rewriter = TypeCoercionRewriter { schema }; + let result = expr.rewrite(&mut rewriter)?; + + let schema = Arc::new(DFSchema::new_with_metadata( + vec![DFField::new_unqualified( + "item", + DataType::List(Arc::new(Field::new("a", DataType::Int32, 
true))), + true, + )], + std::collections::HashMap::new(), + )?); + let expected_casted_expr = cast_expr( + &val, + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + &schema, + )?; + + let expected = Expr::ScalarFunction(ScalarFunction { + fun: BuiltinScalarFunction::MakeArray, + args: vec![expected_casted_expr], + }); + + assert_eq!(result, expected); + Ok(()) + } + #[test] fn test_type_coercion_rewrite() -> Result<()> { // gt diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 911c94b06d765..bddeef526a4d0 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -111,7 +111,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { DataType::List(..) => { let arrays = downcast_vec!(args, ListArray).collect::>>()?; - let len: i32 = arrays.len() as i32; + let len = arrays.iter().map(|arr| arr.len() as i32).sum(); let capacity = Capacities::Array(arrays.iter().map(|a| a.get_array_memory_size()).sum()); let array_data: Vec<_> = @@ -125,7 +125,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { } let list_data_type = - DataType::List(Arc::new(Field::new("item", data_type, false))); + DataType::List(Arc::new(Field::new("item", data_type, true))); let list_data = ArrayData::builder(list_data_type) .len(1) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d81e92c3f3d34..42702d6b28d01 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1068,6 +1068,10 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } + ScalarValue::Fixedsizelist(..) 
=> Err(Error::General( + "Proto serialization error: ScalarValue::Fixedsizelist not supported" + .to_string(), + )), ScalarValue::List(values, boxed_field) => { let is_null = values.is_none(); diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 91a42f4736421..46957a9cdd86c 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -18,9 +18,9 @@ //! Implementation of the `arrow_cast` function that allows //! casting to arbitrary arrow types (rather than SQL types) -use std::{fmt::Display, iter::Peekable, str::Chars}; +use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use arrow_schema::{DataType, IntervalUnit, TimeUnit}; +use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{Expr, ExprSchemable}; @@ -150,6 +150,7 @@ impl<'a> Parser<'a> { Token::Decimal128 => self.parse_decimal_128(), Token::Decimal256 => self.parse_decimal_256(), Token::Dictionary => self.parse_dictionary(), + Token::List => self.parse_list(), tok => Err(make_error( self.val, &format!("finding next type, got unexpected '{tok}'"), @@ -157,6 +158,16 @@ impl<'a> Parser<'a> { } } + /// Parses the List type + fn parse_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::List(Arc::new(Field::new( + "item", data_type, true, + )))) + } + /// Parses the next timeunit fn parse_time_unit(&mut self, context: &str) -> Result { match self.next_token()? 
{ @@ -486,6 +497,8 @@ impl<'a> Tokenizer<'a> { "Date32" => Token::SimpleType(DataType::Date32), "Date64" => Token::SimpleType(DataType::Date64), + "List" => Token::List, + "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond), @@ -573,12 +586,14 @@ enum Token { None, Integer(i64), DoubleQuotedString(String), + List, } impl Display for Token { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Token::SimpleType(t) => write!(f, "{t}"), + Token::List => write!(f, "List"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"),