diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs
index 87a02ae0118c0..39d3af4a20ce3 100644
--- a/datafusion/core/src/logical_plan/mod.rs
+++ b/datafusion/core/src/logical_plan/mod.rs
@@ -27,11 +27,11 @@ pub use datafusion_common::{
     Column, DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema,
 };
 pub use datafusion_expr::{
-    abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
-    atan2, avg, bit_length, btrim, call_fn, case, cast, ceil, character_length, chr,
-    coalesce, col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos,
-    count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest,
-    exists, exp, expr_rewriter,
+    abs, acos, and, approx_distinct, approx_percentile_cont, ascii, asin, atan, atan2,
+    avg, bit_length, btrim, call_fn, case, cast, ceil, character_length, chr, coalesce,
+    col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count,
+    count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp,
+    expr_rewriter,
     expr_rewriter::{
         normalize_col, normalize_col_with_schemas, normalize_cols, replace_col,
         rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, ExprRewritable,
@@ -50,11 +50,11 @@ pub use datafusion_expr::{
         StringifiedPlan, Subquery, TableScan, ToStringifiedPlan, Union,
         UserDefinedLogicalNode, Values,
     },
-    lower, lpad, ltrim, max, md5, min, not_exists, not_in_subquery, now, now_expr,
-    nullif, octet_length, or, power, random, regexp_match, regexp_replace, repeat,
-    replace, reverse, right, round, rpad, rtrim, scalar_subquery, sha224, sha256, sha384,
-    sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
-    to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trim,
-    trunc, unalias, upper, when, Expr, ExprSchemable, Literal, Operator,
+    lower, lpad, ltrim, make_array, max, md5, min, not_exists, not_in_subquery, now,
+    now_expr, nullif, octet_length, or, power, random, regexp_match, regexp_replace,
+    repeat, replace, reverse, right, round, rpad, rtrim, scalar_subquery, sha224, sha256,
+    sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan,
+    to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate,
+    trim, trunc, unalias, upper, when, Expr, ExprSchemable, Literal, Operator,
 };
 pub use datafusion_optimizer::expr_simplifier::{ExprSimplifiable, SimplifyInfo};
diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs
index edae225d87474..a4cc8e5c3c927 100644
--- a/datafusion/core/src/prelude.rs
+++ b/datafusion/core/src/prelude.rs
@@ -31,10 +31,10 @@ pub use crate::execution::options::{
     AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
 };
 pub use crate::logical_plan::{
-    approx_percentile_cont, array, ascii, avg, bit_length, btrim, cast, character_length,
-    chr, coalesce, col, concat, concat_ws, count, create_udf, date_part, date_trunc,
-    digest, exists, from_unixtime, in_list, in_subquery, initcap, left, length, lit,
-    lower, lpad, ltrim, max, md5, min, not_exists, not_in_subquery, now, octet_length,
+    approx_percentile_cont, ascii, avg, bit_length, btrim, cast, character_length, chr,
+    coalesce, col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest,
+    exists, from_unixtime, in_list, in_subquery, initcap, left, length, lit, lower, lpad,
+    ltrim, make_array, max, md5, min, not_exists, not_in_subquery, now, octet_length,
     random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim,
     scalar_subquery, sha224, sha256, sha384, sha512, split_part, starts_with, strpos,
     substr, sum, to_hex, translate, trim, upper, Column, Expr, JoinType, Partitioning,
diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs
index 88c00b45ef1cf..aa5a6725dcf59 100644
--- a/datafusion/core/tests/sql/functions.rs
+++ b/datafusion/core/tests/sql/functions.rs
@@ -111,8 +111,8 @@ async fn query_concat() -> Result<()> {
     Ok(())
 }
 
-#[tokio::test]
-async fn query_array() -> Result<()> {
+// Return a session context with table "test" registered with 2 columns
+fn array_context() -> SessionContext {
     let schema = Arc::new(Schema::new(vec![
         Field::new("c1", DataType::Utf8, false),
         Field::new("c2", DataType::Int32, true),
@@ -124,43 +124,110 @@ async fn query_array() -> Result<()> {
             Arc::new(StringArray::from_slice(&["", "a", "aa", "aaa"])),
             Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])),
         ],
-    )?;
+    )
+    .unwrap();
 
-    let table = MemTable::try_new(schema, vec![vec![data]])?;
+    let table = MemTable::try_new(schema, vec![vec![data]]).unwrap();
     let ctx = SessionContext::new();
-    ctx.register_table("test", Arc::new(table))?;
-    let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test";
+    ctx.register_table("test", Arc::new(table)).unwrap();
+    ctx
+}
+
+#[tokio::test]
+async fn query_array() {
+    let ctx = array_context();
+    let sql = "SELECT array[c1, cast(c2 as varchar)] FROM test";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+--------------------------------------+",
-        "| array(test.c1,CAST(test.c2 AS Utf8)) |",
-        "+--------------------------------------+",
-        "| [, 0]                                |",
-        "| [a, 1]                               |",
-        "| [aa, ]                               |",
-        "| [aaa, 3]                             |",
-        "+--------------------------------------+",
+        "+----------+",
+        "| array    |",
+        "+----------+",
+        "| [, 0]    |",
+        "| [a, 1]   |",
+        "| [aa, ]   |",
+        "| [aaa, 3] |",
+        "+----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+}
+
+#[tokio::test]
+async fn query_make_array() {
+    let ctx = array_context();
+    let sql = "SELECT make_array(c1, cast(c2 as varchar)) FROM test";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------------------------------+",
+        "| makearray(test.c1,CAST(test.c2 AS Utf8))  |",
+        "+-------------------------------------------+",
+        "| [, 0]                                     |",
+        "| [a, 1]                                    |",
+        "| [aa, ]                                    |",
+        "| [aaa, 3]                                  |",
+        "+-------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
-    Ok(())
 }
 
 #[tokio::test]
-async fn query_array_scalar() -> Result<()> {
+async fn query_array_scalar() {
     let ctx = SessionContext::new();
 
-    let sql = "SELECT array(1, 2, 3);";
+    let sql = "SELECT array[1, 2, 3];";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+------------------------------------+",
-        "| array(Int64(1),Int64(2),Int64(3))  |",
-        "+------------------------------------+",
-        "| [1, 2, 3]                          |",
-        "+------------------------------------+",
+        "+-----------+",
+        "| array     |",
+        "+-----------+",
+        "| [1, 2, 3] |",
+        "+-----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    // alternate syntax format
+    let sql = "SELECT [1, 2, 3];";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+}
+
+#[tokio::test]
+async fn query_array_scalar_bad_types() {
+    let ctx = SessionContext::new();
+
+    // no common type to coerce to, should error
+    let err = plan_and_collect(&ctx, "SELECT [1, true, null]")
+        .await
+        .unwrap_err();
+    assert_eq!(err.to_string(), "Error during planning: Coercion from [Int64, Boolean, Null] to the signature VariadicEqual failed.",);
+}
+
+#[tokio::test]
+async fn query_array_scalar_coerce() {
+    let ctx = SessionContext::new();
+
+    // The planner should be able to coerce this to all integers
+    // https://github.com/apache/arrow-datafusion/issues/3170
+    let err = plan_and_collect(&ctx, "SELECT [1, 2, '3']")
+        .await
+        .unwrap_err();
+    assert_eq!(err.to_string(), "Error during planning: Coercion from [Int64, Int64, Utf8] to the signature VariadicEqual failed.",);
+}
+
+#[tokio::test]
+async fn query_make_array_scalar() {
+    let ctx = SessionContext::new();
+
+    let sql = "SELECT make_array(1, 2, 3);";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----------------------------------------+",
+        "| makearray(Int64(1),Int64(2),Int64(3))  |",
+        "+----------------------------------------+",
+        "| [1, 2, 3]                              |",
+        "+----------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
-    Ok(())
 }
 
 #[tokio::test]
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 532699a37cbbb..45214266fccf5 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -71,9 +71,11 @@ pub enum BuiltinScalarFunction {
     /// trunc
     Trunc,
 
-    // string functions
+    // array functions
     /// construct an array from columns
-    Array,
+    MakeArray,
+
+    // string functions
     /// ascii
     Ascii,
     /// bit_length
@@ -204,7 +206,7 @@ impl BuiltinScalarFunction {
            BuiltinScalarFunction::Sqrt => Volatility::Immutable,
            BuiltinScalarFunction::Tan => Volatility::Immutable,
            BuiltinScalarFunction::Trunc => Volatility::Immutable,
-           BuiltinScalarFunction::Array => Volatility::Immutable,
+           BuiltinScalarFunction::MakeArray => Volatility::Immutable,
            BuiltinScalarFunction::Ascii => Volatility::Immutable,
            BuiltinScalarFunction::BitLength => Volatility::Immutable,
            BuiltinScalarFunction::Btrim => Volatility::Immutable,
@@ -297,8 +299,10 @@ impl FromStr for BuiltinScalarFunction {
            // conditional functions
            "coalesce" => BuiltinScalarFunction::Coalesce,
 
+           // array functions
+           "make_array" => BuiltinScalarFunction::MakeArray,
+
            // string functions
-           "array" => BuiltinScalarFunction::Array,
            "ascii" => BuiltinScalarFunction::Ascii,
            "bit_length" => BuiltinScalarFunction::BitLength,
            "btrim" => BuiltinScalarFunction::Btrim,
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 09ac0c2870413..1731d42640d4b 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -382,9 +382,9 @@ scalar_expr!(FromUnixtime, from_unixtime, unixtime);
 unary_scalar_expr!(ArrowTypeof, arrow_typeof, "data type");
 
 /// Returns an array of fixed size with each argument on it.
-pub fn array(args: Vec<Expr>) -> Expr {
+pub fn make_array(args: Vec<Expr>) -> Expr {
     Expr::ScalarFunction {
-        fun: built_in_function::BuiltinScalarFunction::Array,
+        fun: built_in_function::BuiltinScalarFunction::MakeArray,
         args,
     }
 }
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 5cf42fbd21243..1d7de6b651ebc 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -21,8 +21,8 @@ use crate::nullif::SUPPORTED_NULLIF_TYPES;
 use crate::type_coercion::data_types;
 use crate::ColumnarValue;
 use crate::{
-    array_expressions, conditional_expressions, struct_expressions, Accumulator,
-    BuiltinScalarFunction, Signature, TypeSignature,
+    conditional_expressions, struct_expressions, Accumulator, BuiltinScalarFunction,
+    Signature, TypeSignature,
 };
 use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
 use datafusion_common::{DataFusionError, Result};
@@ -96,7 +96,7 @@ pub fn return_type(
     // the return type of the built in function.
     // Some built-in functions' return type depends on the incoming type.
     match fun {
-        BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList(
+        BuiltinScalarFunction::MakeArray => Ok(DataType::FixedSizeList(
            Box::new(Field::new("item", input_expr_types[0].clone(), true)),
            input_expr_types.len() as i32,
        )),
@@ -267,12 +267,8 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
     // note: the physical expression must accept the type returned by this function or the execution panics.
 
-    // for now, the list is small, as we do not have many built-in functions.
     match fun {
-        BuiltinScalarFunction::Array => Signature::variadic(
-            array_expressions::SUPPORTED_ARRAY_TYPES.to_vec(),
-            fun.volatility(),
-        ),
+        BuiltinScalarFunction::MakeArray => Signature::variadic_equal(fun.volatility()),
         BuiltinScalarFunction::Struct => Signature::variadic(
            struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(),
            fun.volatility(),
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 84e6732e39997..216ccef46d433 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -34,7 +34,10 @@
     }};
 }
 
-macro_rules! array {
+/// Create an array of FixedSizeList from a set of individual Arrays
+/// where each element in the output FixedSizeList is the result of
+/// concatenating the corresponding values in the input Arrays
+macro_rules! make_fixed_size_list {
     ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{
         // downcast all arguments to their common format
         let args =
@@ -59,7 +62,7 @@
     }};
 }
 
-fn array_array(args: &[ArrayRef]) -> Result<ArrayRef> {
+fn arrays_to_fixed_size_list_array(args: &[ArrayRef]) -> Result<ArrayRef> {
     // do not accept 0 arguments.
     if args.is_empty() {
         return Err(DataFusionError::Internal(
@@ -68,19 +71,21 @@ fn array_array(args: &[ArrayRef]) -> Result<ArrayRef> {
     }
 
     let res = match args[0].data_type() {
-        DataType::Utf8 => array!(args, StringArray, StringBuilder),
-        DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder),
-        DataType::Boolean => array!(args, BooleanArray, BooleanBuilder),
-        DataType::Float32 => array!(args, Float32Array, Float32Builder),
-        DataType::Float64 => array!(args, Float64Array, Float64Builder),
-        DataType::Int8 => array!(args, Int8Array, Int8Builder),
-        DataType::Int16 => array!(args, Int16Array, Int16Builder),
-        DataType::Int32 => array!(args, Int32Array, Int32Builder),
-        DataType::Int64 => array!(args, Int64Array, Int64Builder),
-        DataType::UInt8 => array!(args, UInt8Array, UInt8Builder),
-        DataType::UInt16 => array!(args, UInt16Array, UInt16Builder),
-        DataType::UInt32 => array!(args, UInt32Array, UInt32Builder),
-        DataType::UInt64 => array!(args, UInt64Array, UInt64Builder),
+        DataType::Utf8 => make_fixed_size_list!(args, StringArray, StringBuilder),
+        DataType::LargeUtf8 => {
+            make_fixed_size_list!(args, LargeStringArray, LargeStringBuilder)
+        }
+        DataType::Boolean => make_fixed_size_list!(args, BooleanArray, BooleanBuilder),
+        DataType::Float32 => make_fixed_size_list!(args, Float32Array, Float32Builder),
+        DataType::Float64 => make_fixed_size_list!(args, Float64Array, Float64Builder),
+        DataType::Int8 => make_fixed_size_list!(args, Int8Array, Int8Builder),
+        DataType::Int16 => make_fixed_size_list!(args, Int16Array, Int16Builder),
+        DataType::Int32 => make_fixed_size_list!(args, Int32Array, Int32Builder),
+        DataType::Int64 => make_fixed_size_list!(args, Int64Array, Int64Builder),
+        DataType::UInt8 => make_fixed_size_list!(args, UInt8Array, UInt8Builder),
+        DataType::UInt16 => make_fixed_size_list!(args, UInt16Array, UInt16Builder),
+        DataType::UInt32 => make_fixed_size_list!(args, UInt32Array, UInt32Builder),
+        DataType::UInt64 => make_fixed_size_list!(args, UInt64Array, UInt64Builder),
         data_type => {
             return Err(DataFusionError::NotImplemented(format!(
                 "Array is not implemented for type '{:?}'.",
@@ -92,7 +97,7 @@ fn array_array(args: &[ArrayRef]) -> Result<ArrayRef> {
 }
 
 /// put values in an array.
-pub fn array(values: &[ColumnarValue]) -> Result<ColumnarValue> {
+pub fn make_array(values: &[ColumnarValue]) -> Result<ColumnarValue> {
     let arrays: Vec<ArrayRef> = values
         .iter()
         .map(|x| match x {
@@ -100,5 +105,7 @@ pub fn array(values: &[ColumnarValue]) -> Result<ColumnarValue> {
            ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
        })
        .collect();
-    Ok(ColumnarValue::Array(array_array(arrays.as_slice())?))
+    Ok(ColumnarValue::Array(arrays_to_fixed_size_list_array(
+        arrays.as_slice(),
+    )?))
 }
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index dde0ee0a06bef..6d833a02651b5 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -322,7 +322,7 @@ pub fn create_physical_fun(
         }
 
         // string functions
-        BuiltinScalarFunction::Array => Arc::new(array_expressions::array),
+        BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::make_array),
         BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr),
         BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() {
             DataType::Utf8 => {
@@ -2727,7 +2727,7 @@ mod tests {
         value2: ArrayRef,
         expected_type: DataType,
         expected: &str,
-    ) -> Result<()> {
+    ) {
         // any type works here: we evaluate against a literal of `value`
         let schema = Schema::new(vec![
             Field::new("a", value1.data_type().clone(), false),
@@ -2737,22 +2737,23 @@ mod tests {
 
         let execution_props = ExecutionProps::new();
         let expr = create_physical_expr(
-            &BuiltinScalarFunction::Array,
-            &[col("a", &schema)?, col("b", &schema)?],
+            &BuiltinScalarFunction::MakeArray,
+            &[col("a", &schema).unwrap(), col("b", &schema).unwrap()],
             &schema,
             &execution_props,
-        )?;
+        )
+        .unwrap();
 
         // type is correct
         assert_eq!(
-            expr.data_type(&schema)?,
+            expr.data_type(&schema).unwrap(),
             // type equals to a common coercion
             DataType::FixedSizeList(Box::new(Field::new("item", expected_type, true)), 2)
         );
 
         // evaluate works
-        let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
-        let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns).unwrap();
+        let result = expr.evaluate(&batch).unwrap().into_array(batch.num_rows());
 
         // downcast works
         let result = result
@@ -2762,28 +2763,20 @@ mod tests {
 
         // value is correct
         assert_eq!(format!("{:?}", result.value(0)), expected);
-
-        Ok(())
     }
 
     #[test]
-    fn test_array() -> Result<()> {
+    fn test_array() {
         generic_test_array(
             Arc::new(StringArray::from_slice(&["aa"])),
             Arc::new(StringArray::from_slice(&["bb"])),
             DataType::Utf8,
             "StringArray\n[\n  \"aa\",\n  \"bb\",\n]",
-        )?;
-
-        // different types, to validate that casting happens
-        generic_test_array(
-            Arc::new(UInt32Array::from_slice(&[1u32])),
-            Arc::new(UInt64Array::from_slice(&[1u64])),
-            DataType::UInt64,
-            "PrimitiveArray<UInt64>\n[\n  1,\n  1,\n]",
-        )?;
+        );
 
-        // different types (another order), to validate that casting happens
+        // different types (first argument type is used, so can cast,
+        // to validate that casting happens (coercion from u32 to u64
+        // is ok, but u64 to u32 might lose data).
         generic_test_array(
             Arc::new(UInt64Array::from_slice(&[1u64])),
             Arc::new(UInt32Array::from_slice(&[1u32])),
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 524b03bd69333..a71aa6cb06c1d 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -32,16 +32,17 @@ use datafusion_common::{
 use datafusion_expr::expr::GroupingSet;
 use datafusion_expr::expr::GroupingSet::GroupingSets;
 use datafusion_expr::{
-    abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil,
-    character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_bin,
-    date_part, date_trunc, digest, exp, floor, from_unixtime, left, ln, log10, log2,
+    abs, acos, ascii, asin, atan, atan2, bit_length, btrim, ceil, character_length, chr,
+    coalesce, concat_expr, concat_ws_expr, cos, date_bin, date_part, date_trunc, digest,
+    exp, floor, from_unixtime, left, ln, log10, log2,
     logical_plan::{PlanType, StringifiedPlan},
-    lower, lpad, ltrim, md5, now_expr, nullif, octet_length, power, random, regexp_match,
-    regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256,
-    sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, tan,
-    to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate,
-    trim, trunc, upper, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction,
-    Expr, Operator, WindowFrame, WindowFrameBound, WindowFrameUnits,
+    lower, lpad, ltrim, make_array, md5, now_expr, nullif, octet_length, power, random,
+    regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
+    sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos,
+    substr, tan, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds,
+    translate, trim, trunc, upper, AggregateFunction, BuiltInWindowFunction,
+    BuiltinScalarFunction, Expr, Operator, WindowFrame, WindowFrameBound,
+    WindowFrameUnits,
 };
 use std::sync::Arc;
 
@@ -431,7 +432,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
            ScalarFunction::Ltrim => Self::Ltrim,
            ScalarFunction::Rtrim => Self::Rtrim,
            ScalarFunction::ToTimestamp => Self::ToTimestamp,
-           ScalarFunction::Array => Self::Array,
+           ScalarFunction::Array => Self::MakeArray,
            ScalarFunction::NullIf => Self::NullIf,
            ScalarFunction::DatePart => Self::DatePart,
            ScalarFunction::DateTrunc => Self::DateTrunc,
@@ -968,7 +969,7 @@ pub fn parse_expr(
         match scalar_function {
            ScalarFunction::Asin => Ok(asin(parse_expr(&args[0], registry)?)),
            ScalarFunction::Acos => Ok(acos(parse_expr(&args[0], registry)?)),
-           ScalarFunction::Array => Ok(array(
+           ScalarFunction::Array => Ok(make_array(
                args.to_owned()
                    .iter()
                    .map(|expr| parse_expr(expr, registry))
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 045b97a3188df..a022769dcab19 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -1083,7 +1083,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
            BuiltinScalarFunction::Ltrim => Self::Ltrim,
            BuiltinScalarFunction::Rtrim => Self::Rtrim,
            BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp,
-           BuiltinScalarFunction::Array => Self::Array,
+           BuiltinScalarFunction::MakeArray => Self::Array,
            BuiltinScalarFunction::NullIf => Self::NullIf,
            BuiltinScalarFunction::DatePart => Self::DatePart,
            BuiltinScalarFunction::DateTrunc => Self::DateTrunc,
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 28c82f80246f9..9b1bd01e3755e 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -1690,7 +1690,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                fractional_seconds_precision,
            ),
 
-            SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
+            SQLExpr::Array(arr) => self.sql_array_expr(arr.elem, schema),
 
            SQLExpr::Identifier(id) => {
                if id.value.starts_with('@') {
@@ -2383,50 +2383,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .is_ok()
     }
 
-    fn sql_array_literal(
-        &self,
-        elements: Vec<SQLExpr>,
-        schema: &DFSchema,
-    ) -> Result<Expr> {
-        let mut values = Vec::with_capacity(elements.len());
-
-        for element in elements {
-            let value =
-                self.sql_expr_to_logical_expr(element, schema, &mut HashMap::new())?;
-            match value {
-                Expr::Literal(scalar) => {
-                    values.push(scalar);
-                }
-                _ => {
-                    return Err(DataFusionError::NotImplemented(format!(
-                        "Arrays with elements other than literal are not supported: {}",
-                        value
-                    )));
-                }
-            }
-        }
-
-        let data_types: HashSet<DataType> =
-            values.iter().map(|e| e.get_datatype()).collect();
-
-        if data_types.is_empty() {
-            Ok(Expr::Literal(ScalarValue::List(
-                None,
-                Box::new(Field::new("item", DataType::Utf8, true)),
-            )))
-        } else if data_types.len() > 1 {
-            Err(DataFusionError::NotImplemented(format!(
-                "Arrays with different types are not supported: {:?}",
-                data_types,
-            )))
-        } else {
-            let data_type = values[0].get_datatype();
+    fn sql_array_expr(&self, elements: Vec<SQLExpr>, schema: &DFSchema) -> Result<Expr> {
+        let args: Vec<Expr> = elements
+            .into_iter()
+            .map(|expr| self.sql_expr_to_logical_expr(expr, schema, &mut HashMap::new()))
+            .collect::<Result<Vec<_>>>()?;
 
-            Ok(Expr::Literal(ScalarValue::List(
-                Some(values),
-                Box::new(Field::new("item", data_type, true)),
-            )))
-        }
+        let fun = BuiltinScalarFunction::MakeArray;
+        // follow postgres convention and name result "array"
+        Ok(Expr::ScalarFunction { fun, args }.alias("array"))
     }
 }
 
@@ -2617,7 +2582,6 @@ fn parse_sql_number(n: &str) -> Result<Expr> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::assert_contains;
     use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
     use std::any::Any;
 
@@ -3353,25 +3317,12 @@ mod tests {
         );
     }
 
-    #[test]
-    fn select_array_no_common_type() {
-        let sql = "SELECT [1, true, null]";
-        let err = logical_plan(sql).expect_err("query should have failed");
-
-        // HashSet doesn't guarantee order
-        assert_contains!(
-            err.to_string(),
-            r#"Arrays with different types are not supported: "#
-        );
-    }
-
     #[test]
     fn select_array_non_literal_type() {
-        let sql = "SELECT [now()]";
-        let err = logical_plan(sql).expect_err("query should have failed");
-        assert_eq!(
-            r#"NotImplemented("Arrays with elements other than literal are not supported: now()")"#,
-            format!("{:?}", err)
+        quick_test(
+            "SELECT [now()]",
+            "Projection: makearray(now()) AS array\
+            \n  EmptyRelation",
         );
     }