diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 2aa4f2e45fc..245ca3aaaa8 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1144,9 +1144,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { let expr = create_name(expr, input_schema)?; let list = list.iter().map(|expr| create_name(expr, input_schema)); if *negated { - Ok(format!("{:?} NOT IN ({:?})", expr, list)) + Ok(format!("{} NOT IN ({:?})", expr, list)) } else { - Ok(format!("{:?} IN ({:?})", expr, list)) + Ok(format!("{} IN ({:?})", expr, list)) } } other => Err(DataFusionError::NotImplemented(format!( diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 6244387e180..0de0a032520 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -29,7 +29,6 @@ mod extension; mod operators; mod plan; mod registry; - pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; diff --git a/rust/datafusion/src/physical_plan/datetime_expressions.rs b/rust/datafusion/src/physical_plan/datetime_expressions.rs index 8642e3b40e3..3d363ce97d2 100644 --- a/rust/datafusion/src/physical_plan/datetime_expressions.rs +++ b/rust/datafusion/src/physical_plan/datetime_expressions.rs @@ -16,27 +16,30 @@ // under the License. //! DateTime expressions - use std::sync::Arc; +use super::ColumnarValue; use crate::{ error::{DataFusionError, Result}, scalar::{ScalarType, ScalarValue}, }; -use arrow::temporal_conversions::timestamp_ns_to_datetime; +use arrow::{ + array::{Array, ArrayRef, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait}, + datatypes::{ArrowPrimitiveType, DataType, TimestampNanosecondType}, +}; use arrow::{ array::{ - Array, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait, - TimestampNanosecondArray, + Date32Array, Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, }, - datatypes::{ArrowPrimitiveType, DataType, TimestampNanosecondType}, + compute::kernels::temporal, + datatypes::TimeUnit, + temporal_conversions::timestamp_ns_to_datetime, }; use chrono::prelude::*; use chrono::Duration; use chrono::LocalResult; -use super::ColumnarValue; - #[inline] /// Accepts a string in RFC3339 / ISO8601 standard format and some /// variants and converts it to a nanosecond precision timestamp. @@ -344,6 +347,98 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { }) } +macro_rules! extract_date_part { + ($ARRAY: expr, $FN:expr) => { + match $ARRAY.data_type() { + DataType::Date32 => { + let array = $ARRAY.as_any().downcast_ref::().unwrap(); + Ok($FN(array)?) + } + DataType::Date64 => { + let array = $ARRAY.as_any().downcast_ref::().unwrap(); + Ok($FN(array)?) + } + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + TimeUnit::Millisecond => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + TimeUnit::Microsecond => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + TimeUnit::Nanosecond => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + }, + datatype => Err(DataFusionError::Internal(format!( + "Extract does not support datatype {:?}", + datatype + ))), + } + }; +} + +/// DATE_PART SQL function +pub fn date_part(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Execution( + "Expected two arguments in DATE_PART".to_string(), + )); + } + let (date_part, array) = (&args[0], &args[1]); + + let date_part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = date_part { + v + } else { + return Err(DataFusionError::Execution( + "First argument of `DATE_PART` must be non-null scalar Utf8".to_string(), + )); + }; + + let is_scalar = matches!(array, ColumnarValue::Scalar(_)); + + let array = match array { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array(), + }; + + let arr = match date_part.to_lowercase().as_str() { + "hour" => extract_date_part!(array, temporal::hour), + "year" => extract_date_part!(array, temporal::year), + _ => Err(DataFusionError::Execution(format!( + "Date part '{}' not supported", + date_part + ))), + }?; + + Ok(if is_scalar { + ColumnarValue::Scalar(ScalarValue::try_from_array( + &(Arc::new(arr) as ArrayRef), + 0, + )?) + } else { + ColumnarValue::Array(Arc::new(arr)) + }) +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/rust/datafusion/src/physical_plan/expressions/mod.rs b/rust/datafusion/src/physical_plan/expressions/mod.rs index bf47aa1cfe8..fe5fea1e2e4 100644 --- a/rust/datafusion/src/physical_plan/expressions/mod.rs +++ b/rust/datafusion/src/physical_plan/expressions/mod.rs @@ -58,7 +58,6 @@ pub use negative::{negative, NegativeExpr}; pub use not::{not, NotExpr}; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use sum::{sum_return_type, Sum}; - /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{}[{}]", name, state_name) diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index baacf949270..51941188bb4 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -71,6 +71,8 @@ pub enum Signature { Exact(Vec), /// fixed number of arguments of arbitrary types Any(usize), + /// One of a list of signatures + OneOf(Vec), } /// Scalar function @@ -138,6 +140,8 @@ pub enum BuiltinScalarFunction { NullIf, /// Date truncate DateTrunc, + /// Date part + DatePart, /// MD5 MD5, /// SHA224 @@ -192,6 +196,7 @@ impl FromStr for BuiltinScalarFunction { "upper" => BuiltinScalarFunction::Upper, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "date_trunc" => BuiltinScalarFunction::DateTrunc, + "date_part" => BuiltinScalarFunction::DatePart, "array" => BuiltinScalarFunction::Array, "nullif" => BuiltinScalarFunction::NullIf, "md5" => BuiltinScalarFunction::MD5, @@ -294,6 +299,7 @@ pub fn return_type( BuiltinScalarFunction::DateTrunc => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } + BuiltinScalarFunction::DatePart => Ok(DataType::Int32), BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( Box::new(Field::new("item", arg_types[0].clone(), true)), arg_types.len() as i32, @@ -463,6 +469,7 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::DatePart => datetime_expressions::date_part, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -507,6 +514,26 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { DataType::Utf8, DataType::Timestamp(TimeUnit::Nanosecond, None), ]), + BuiltinScalarFunction::DatePart => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Date32]), + Signature::Exact(vec![DataType::Utf8, DataType::Date64]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Second, None), + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Microsecond, None), + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Millisecond, None), + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Nanosecond, None), + ]), + ]), BuiltinScalarFunction::Array => { Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) } diff --git a/rust/datafusion/src/physical_plan/type_coercion.rs b/rust/datafusion/src/physical_plan/type_coercion.rs index a84707a48df..ae920cb870f 100644 --- a/rust/datafusion/src/physical_plan/type_coercion.rs +++ b/rust/datafusion/src/physical_plan/type_coercion.rs @@ -29,7 +29,7 @@ //! i64. However, i64 -> i32 is never performed as there are i64 //! values which can not be represented by i32 values. -use std::sync::Arc; +use std::{sync::Arc, vec}; use arrow::datatypes::{DataType, Schema, TimeUnit}; @@ -68,6 +68,32 @@ pub fn data_types( current_types: &[DataType], signature: &Signature, ) -> Result> { + let valid_types = get_valid_types(signature, current_types)?; + + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + for valid_types in valid_types { + if let Some(types) = maybe_data_types(&valid_types, ¤t_types) { + return Ok(types); + } + } + + // none possible -> Error + Err(DataFusionError::Plan(format!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, signature + ))) +} + +fn get_valid_types( + signature: &Signature, + current_types: &[DataType], +) -> Result>> { let valid_types = match signature { Signature::Variadic(valid_types) => valid_types .iter() @@ -95,23 +121,16 @@ pub fn data_types( } vec![(0..*number).map(|i| current_types[i].clone()).collect()] } - }; - - if valid_types.contains(¤t_types.to_owned()) { - return Ok(current_types.to_vec()); - } - - for valid_types in valid_types { - if let Some(types) = maybe_data_types(&valid_types, ¤t_types) { - return Ok(types); + Signature::OneOf(types) => { + let mut r = vec![]; + for s in types { + r.extend(get_valid_types(s, current_types)?); + } + r } - } + }; - // none possible -> Error - Err(DataFusionError::Plan(format!( - "Coercion from {:?} to the signature {:?} failed.", - current_types, signature - ))) + Ok(valid_types) } /// Try to coerce current_types into valid_types. diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index fc56052b29f..f985b506536 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -726,6 +726,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::Boolean(n)) => Ok(lit(*n)), SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::DatePart, + args: vec![ + Expr::Literal(ScalarValue::Utf8(Some(format!("{}", field)))), + self.sql_expr_to_logical_expr(expr)?, + ], + }), SQLExpr::Value(Value::Interval { value, diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index d5a278d9301..2f780b662b8 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1717,7 +1717,7 @@ fn make_timestamp_nano_table() -> Result> { } #[tokio::test] -async fn to_timstamp() -> Result<()> { +async fn to_timestamp() -> Result<()> { let mut ctx = ExecutionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?); @@ -2134,6 +2134,24 @@ async fn crypto_expressions() -> Result<()> { Ok(()) } +#[tokio::test] +async fn extract_date_part() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT + date_part('hour', CAST('2020-01-01' AS DATE)) AS hr1, + EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE)) AS hr2, + EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS hr3, + date_part('YEAR', CAST('2000-01-01' AS DATE)) AS year1, + EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS year2 + "; + + let actual = execute(&mut ctx, sql).await; + + let expected = vec![vec!["0", "0", "12", "2000", "2020"]]; + assert_eq!(expected, actual); + Ok(()) +} + #[tokio::test] async fn in_list_array() -> Result<()> { let mut ctx = ExecutionContext::new();