diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 8f1c006510f68..fee24740a6f19 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -52,31 +52,242 @@ SELECT arrow_typeof(1.0::float) Float32 # arrow_typeof_decimal -# query T -# SELECT arrow_typeof(1::Decimal) -# ---- -# Decimal128(38, 10) - -# # arrow_typeof_timestamp -# query T -# SELECT arrow_typeof(now()::timestamp) -# ---- -# Timestamp(Nanosecond, None) - -# # arrow_typeof_timestamp_utc -# query T -# SELECT arrow_typeof(now()) -# ---- -# Timestamp(Nanosecond, Some(\"+00:00\")) - -# # arrow_typeof_timestamp_date32( -# query T -# SELECT arrow_typeof(now()::date) -# ---- -# Date32 - -# # arrow_typeof_utf8 -# query T -# SELECT arrow_typeof('1') -# ---- -# Utf8 +query T +SELECT arrow_typeof(1::Decimal) +---- +Decimal128(38, 10) + +# arrow_typeof_timestamp +query T +SELECT arrow_typeof(now()::timestamp) +---- +Timestamp(Nanosecond, None) + +# arrow_typeof_timestamp_utc +query T +SELECT arrow_typeof(now()) +---- +Timestamp(Nanosecond, Some("+00:00")) + +# arrow_typeof_timestamp_date32( +query T +SELECT arrow_typeof(now()::date) +---- +Date32 + +# arrow_typeof_utf8 +query T +SELECT arrow_typeof('1') +---- +Utf8 + + +#### arrow_cast (in some ways opposite of arrow_typeof) + +# Basic tests + +query I +SELECT arrow_cast('1', 'Int16') +---- +1 + +# Basic error test +query error Error during planning: arrow_cast needs 2 arguments, 1 provided +SELECT arrow_cast('1') + +query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\) +SELECT arrow_cast('1', 43) + +query error Error unrecognized word: unknown +SELECT arrow_cast('1', 'unknown') + +# Round Trip tests: +query TTTTTTTTTTTTTTTTTTT +SELECT + arrow_typeof(arrow_cast(1, 'Int8')) as col_i8, + arrow_typeof(arrow_cast(1, 
'Int16')) as col_i16, + arrow_typeof(arrow_cast(1, 'Int32')) as col_i32, + arrow_typeof(arrow_cast(1, 'Int64')) as col_i64, + arrow_typeof(arrow_cast(1, 'UInt8')) as col_u8, + arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16, + arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32, + arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64, + -- can't seem to cast to Float16 for some reason + -- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16, + arrow_typeof(arrow_cast(1, 'Float32')) as col_f32, + arrow_typeof(arrow_cast(1, 'Float64')) as col_f64, + arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8, + arrow_typeof(arrow_cast('foo', 'LargeUtf8')) as col_large_utf8, + arrow_typeof(arrow_cast('foo', 'Binary')) as col_binary, + arrow_typeof(arrow_cast('foo', 'LargeBinary')) as col_large_binary, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)')) as col_ts_s, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)')) as col_ts_ms, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)')) as col_ts_us, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)')) as col_ts_ns, + arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict +---- +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Dictionary(Int32, Utf8) + + + + +## Basic Types: Create a table + +statement ok +create table foo as select + arrow_cast(1, 'Int8') as col_i8, + arrow_cast(1, 'Int16') as col_i16, + arrow_cast(1, 'Int32') as col_i32, + arrow_cast(1, 'Int64') as col_i64, + arrow_cast(1, 'UInt8') as col_u8, + arrow_cast(1, 'UInt16') as col_u16, + arrow_cast(1, 'UInt32') as col_u32, + arrow_cast(1, 'UInt64') as col_u64, + -- can't seem to cast to 
Float16 for some reason + -- arrow_cast(1.0, 'Float16') as col_f16, + arrow_cast(1.0, 'Float32') as col_f32, + arrow_cast(1.0, 'Float64') as col_f64 +; + +## Ensure each column in the table has the expected type + +query TTTTTTTTTT +SELECT + arrow_typeof(col_i8), + arrow_typeof(col_i16), + arrow_typeof(col_i32), + arrow_typeof(col_i64), + arrow_typeof(col_u8), + arrow_typeof(col_u16), + arrow_typeof(col_u32), + arrow_typeof(col_u64), + -- arrow_typeof(col_f16), + arrow_typeof(col_f32), + arrow_typeof(col_f64) + FROM foo; +---- +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 + + +statement ok +drop table foo + +## Decimals: Create a table + +statement ok +create table foo as select + arrow_cast(100, 'Decimal128(3,2)') as col_d128 + -- Can't make a decimal 256: + -- This feature is not implemented: Can't create a scalar from array of type "Decimal256(3, 2)" + --arrow_cast(100, 'Decimal256(3,2)') as col_d256 +; + + +## Ensure each column in the table has the expected type + +query T +SELECT + arrow_typeof(col_d128) + -- arrow_typeof(col_d256), + FROM foo; +---- +Decimal128(3, 2) + + +statement ok +drop table foo + +## Strings, Binary: Create a table + +statement ok +create table foo as select + arrow_cast('foo', 'Utf8') as col_utf8, + arrow_cast('foo', 'LargeUtf8') as col_large_utf8, + arrow_cast('foo', 'Binary') as col_binary, + arrow_cast('foo', 'LargeBinary') as col_large_binary +; + +## Ensure each column in the table has the expected type + +query TTTT +SELECT + arrow_typeof(col_utf8), + arrow_typeof(col_large_utf8), + arrow_typeof(col_binary), + arrow_typeof(col_large_binary) + FROM foo; +---- +Utf8 LargeUtf8 Binary LargeBinary + + +statement ok +drop table foo + + +## Timestamps: Create a table + +statement ok +create table foo as select + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') as col_ts_s, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)') as 
col_ts_ms, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)') as col_ts_us, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns +; + +## Ensure each column in the table has the expected type + +query TTTT +SELECT + arrow_typeof(col_ts_s), + arrow_typeof(col_ts_ms), + arrow_typeof(col_ts_us), + arrow_typeof(col_ts_ns) + FROM foo; +---- +Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) + + +statement ok +drop table foo + +## Dictionaries + +statement ok +create table foo as select + arrow_cast('foo', 'Dictionary(Int32, Utf8)') as col_dict_int32_utf8, + arrow_cast('foo', 'Dictionary(Int8, LargeUtf8)') as col_dict_int8_largeutf8 +; + +## Ensure each column in the table has the expected type + +query TT +SELECT + arrow_typeof(col_dict_int32_utf8), + arrow_typeof(col_dict_int8_largeutf8) + FROM foo; +---- +Dictionary(Int32, Utf8) Dictionary(Int8, LargeUtf8) + + +statement ok +drop table foo + + +## Intervals: + +query error Cannot automatically convert Interval\(DayTime\) to Interval\(MonthDayNano\) +--- +select arrow_cast(interval '30 minutes', 'Interval(MonthDayNano)'); + +query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Interval\(MonthDayNano\) +select arrow_cast('30 minutes', 'Interval(MonthDayNano)'); + + +## Duration + +query error Cannot automatically convert Interval\(DayTime\) to Duration\(Second\) +--- +select arrow_cast(interval '30 minutes', 'Duration(Second)'); + +query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) +select arrow_cast('30 minutes', 'Duration(Second)'); diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 70612825989a1..802242b3e9599 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -2120,8 
+2120,11 @@ mod roundtrip_tests { DataType::Float16, DataType::Float32, DataType::Float64, - // Add more timestamp tests + DataType::Timestamp(TimeUnit::Second, None), DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), DataType::Date32, DataType::Date64, DataType::Time32(TimeUnit::Second), diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs new file mode 100644 index 0000000000000..bc1313e2c114e --- /dev/null +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -0,0 +1,719 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation of the `arrow_cast` function that allows +//! 
casting to arbitrary arrow types (rather than SQL types) + +use std::{fmt::Display, iter::Peekable, str::Chars}; + +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; +use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; + +use datafusion_expr::{Expr, ExprSchemable}; + +pub const ARROW_CAST_NAME: &str = "arrow_cast"; + +/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// +/// This function is not a [`BuiltInScalarFunction`] because the +/// return type of [`BuiltInScalarFunction`] depends only on the +/// *types* of the arguments. However, the type of `arrow_cast` depends on +/// the *value* of its second argument. +/// +/// Use the `cast` function to cast to a SQL type (which is then mapped +/// to the corresponding arrow type). For example to cast to `int` +/// (which is then mapped to the arrow type `Int32`) +/// +/// ```sql +/// select cast(column_x as int) ... +/// ``` +/// +/// Use the `arrow_cast` function to cast to a specific arrow type +/// +/// For example +/// ```sql +/// select arrow_cast(column_x, 'Float64') +/// ``` +pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "arrow_cast needs 2 arguments, {} provided", + args.len() + ))); + } + let arg1 = args.pop().unwrap(); + let arg0 = args.pop().unwrap(); + + // arg1 must be a constant string literal + let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { + v + } else { + return Err(DataFusionError::Plan(format!( + "arrow_cast requires its second argument to be a constant string, got {arg1}" + ))); + }; + + // do the actual lookup to the appropriate data type + let data_type = parse_data_type(&data_type_string)?; + + arg0.cast_to(&data_type, schema) +} + +/// Parses `val` into a `DataType`. 
+/// +/// `parse_data_type` is the reverse of [`DataType`]'s `Display` +/// impl, and for the types it supports maintains the invariant that +/// `parse_data_type(data_type.to_string()) == data_type` +/// +/// Example: +/// ``` +/// # use datafusion_sql::parse_data_type; +/// # use arrow_schema::DataType; +/// let display_value = "Int32"; +/// +/// // "Int32" is the Display value of `DataType` +/// assert_eq!(display_value, &format!("{}", DataType::Int32)); +/// +/// // parse_data_type converts "Int32" back to `DataType`: +/// let data_type = parse_data_type(display_value).unwrap(); +/// assert_eq!(data_type, DataType::Int32); +/// ``` +/// +/// Remove if added to arrow: https://github.com/apache/arrow-rs/issues/3821 +pub fn parse_data_type(val: &str) -> Result { + Parser::new(val).parse() +} + +fn make_error(val: &str, msg: &str) -> DataFusionError { + DataFusionError::Plan( + format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" ) + ) +} + +fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError { + make_error(val, &format!("Expected '{expected}', got '{actual}'")) +} + +#[derive(Debug)] +/// Implementation of `parse_data_type`: builds a `DataType` from the tokens of a `Tokenizer` +struct Parser<'a> { + val: &'a str, + tokenizer: Tokenizer<'a>, +} + +impl<'a> Parser<'a> { + fn new(val: &'a str) -> Self { + Self { + val, + tokenizer: Tokenizer::new(val), + } + } + + fn parse(mut self) -> Result { + let data_type = self.parse_next_type()?; + // ensure that there is no trailing content + if self.tokenizer.next().is_some() { + return Err(make_error( + self.val, + &format!("checking trailing content after parsing '{data_type}'"), + )); + } else { + Ok(data_type) + } + } + + /// parses the next full DataType + fn parse_next_type(&mut self) -> Result { + match self.next_token()? 
{ + Token::SimpleType(data_type) => Ok(data_type), + Token::Timestamp => self.parse_timestamp(), + Token::Time32 => self.parse_time32(), + Token::Time64 => self.parse_time64(), + Token::Duration => self.parse_duration(), + Token::Interval => self.parse_interval(), + Token::FixedSizeBinary => self.parse_fixed_size_binary(), + Token::Decimal128 => self.parse_decimal_128(), + Token::Decimal256 => self.parse_decimal_256(), + Token::Dictionary => self.parse_dictionary(), + tok => Err(make_error( + self.val, + &format!("finding next type, got unexpected '{tok}'"), + )), + } + } + + /// Parses the next timeunit + fn parse_time_unit(&mut self, context: &str) -> Result { + match self.next_token()? { + Token::TimeUnit(time_unit) => Ok(time_unit), + tok => Err(make_error( + self.val, + &format!("finding TimeUnit for {context}, got {tok}"), + )), + } + } + + /// Parses the next integer value + fn parse_i64(&mut self, context: &str) -> Result { + match self.next_token()? { + Token::Integer(v) => Ok(v), + tok => Err(make_error( + self.val, + &format!("finding i64 for {context}, got '{tok}'"), + )), + } + } + + /// Parses the next i32 integer value + fn parse_i32(&mut self, context: &str) -> Result { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into i32 for {context}: {e}"), + ) + }) + } + + /// Parses the next i8 integer value + fn parse_i8(&mut self, context: &str) -> Result { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into i8 for {context}: {e}"), + ) + }) + } + + /// Parses the next u8 integer value + fn parse_u8(&mut self, context: &str) -> Result { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into u8 for {context}: {e}"), + ) + }) + } + + /// Parses the next timestamp (called after `Timestamp` has been consumed) 
+ fn parse_timestamp(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Timestamp")?; + self.expect_token(Token::Comma)?; + // TODO Support timezones other than None + self.expect_token(Token::None)?; + let timezone = None; + + self.expect_token(Token::RParen)?; + Ok(DataType::Timestamp(time_unit, timezone)) + } + + /// Parses the next Time32 (called after `Time32` has been consumed) + fn parse_time32(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Time32")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Time32(time_unit)) + } + + /// Parses the next Time64 (called after `Time64` has been consumed) + fn parse_time64(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Time64")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Time64(time_unit)) + } + + /// Parses the next Duration (called after `Duration` has been consumed) + fn parse_duration(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Duration")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Duration(time_unit)) + } + + /// Parses the next Interval (called after `Interval` has been consumed) + fn parse_interval(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let interval_unit = match self.next_token()? 
{ + Token::IntervalUnit(interval_unit) => interval_unit, + tok => { + return Err(make_error( + self.val, + &format!("finding IntervalUnit for Interval, got {tok}"), + )) + } + }; + self.expect_token(Token::RParen)?; + Ok(DataType::Interval(interval_unit)) + } + + /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has been consumed) + fn parse_fixed_size_binary(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let length = self.parse_i32("FixedSizeBinary")?; + self.expect_token(Token::RParen)?; + Ok(DataType::FixedSizeBinary(length)) + } + + /// Parses the next Decimal128 (called after `Decimal128` has been consumed) + fn parse_decimal_128(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let precision = self.parse_u8("Decimal128")?; + self.expect_token(Token::Comma)?; + let scale = self.parse_i8("Decimal128")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Decimal128(precision, scale)) + } + + /// Parses the next Decimal256 (called after `Decimal256` has been consumed) + fn parse_decimal_256(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let precision = self.parse_u8("Decimal256")?; + self.expect_token(Token::Comma)?; + let scale = self.parse_i8("Decimal256")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Decimal256(precision, scale)) + } + + /// Parses the next Dictionary (called after `Dictionary` has been consumed) + fn parse_dictionary(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let key_type = self.parse_next_type()?; + self.expect_token(Token::Comma)?; + let value_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::Dictionary( + Box::new(key_type), + Box::new(value_type), + )) + } + + /// return the next token, or an error if there are none left + fn next_token(&mut self) -> Result { + match self.tokenizer.next() { + None => Err(make_error(self.val, "finding next token")), + Some(token) => token, + } + } + + /// consume the next token, 
returning Ok(()) if it matches tok, and Err if not + fn expect_token(&mut self, tok: Token) -> Result<()> { + let next_token = self.next_token()?; + if next_token == tok { + Ok(()) + } else { + Err(make_error_expected(self.val, &tok, &next_token)) + } + } +} + +/// returns true if this character is a separator +fn is_separator(c: char) -> bool { + c == '(' || c == ')' || c == ',' || c == ' ' +} + +#[derive(Debug)] +/// Splits a string like `Dictionary(Int32, Int64)` into tokens suitable for parsing +/// +/// For example the string "Timestamp(Nanosecond, None)" would be split into: +/// +/// * Token::Timestamp +/// * Token::LParen +/// * Token::TimeUnit(TimeUnit::Nanosecond) +/// * Token::Comma, +/// * Token::None, +/// * Token::RParen, +struct Tokenizer<'a> { + val: &'a str, + chars: Peekable>, + // temporary buffer for parsing words + word: String, +} + +impl<'a> Tokenizer<'a> { + fn new(val: &'a str) -> Self { + Self { + val, + chars: val.chars().peekable(), + word: String::new(), + } + } + + /// returns the next char, without consuming it + fn peek_next_char(&mut self) -> Option { + self.chars.peek().copied() + } + + /// returns the next char, consuming it + fn next_char(&mut self) -> Option { + self.chars.next() + } + + /// parse the characters in val from the current position until the next + /// separator (`,`, `(`, `)` or space) or the end of the input + fn parse_word(&mut self) -> Result { + // reset temp space + self.word.clear(); + loop { + match self.peek_next_char() { + None => break, + Some(c) if is_separator(c) => break, + Some(c) => { + self.next_char(); + self.word.push(c); + } + } + } + + // if it starts with `-` or a numeric char, try parsing it as an integer + if let Some(c) = self.word.chars().next() { + if c == '-' || c.is_numeric() { + let val: i64 = self.word.parse().map_err(|e| { + make_error( + self.val, + &format!("parsing {} as integer: {e}", self.word), + ) + })?; + return Ok(Token::Integer(val)); + } + } + + // figure out what the word was + let token = match 
self.word.as_str() { + "Null" => Token::SimpleType(DataType::Null), + "Boolean" => Token::SimpleType(DataType::Boolean), + + "Int8" => Token::SimpleType(DataType::Int8), + "Int16" => Token::SimpleType(DataType::Int16), + "Int32" => Token::SimpleType(DataType::Int32), + "Int64" => Token::SimpleType(DataType::Int64), + + "UInt8" => Token::SimpleType(DataType::UInt8), + "UInt16" => Token::SimpleType(DataType::UInt16), + "UInt32" => Token::SimpleType(DataType::UInt32), + "UInt64" => Token::SimpleType(DataType::UInt64), + + "Utf8" => Token::SimpleType(DataType::Utf8), + "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8), + "Binary" => Token::SimpleType(DataType::Binary), + "LargeBinary" => Token::SimpleType(DataType::LargeBinary), + + "Float16" => Token::SimpleType(DataType::Float16), + "Float32" => Token::SimpleType(DataType::Float32), + "Float64" => Token::SimpleType(DataType::Float64), + + "Date32" => Token::SimpleType(DataType::Date32), + "Date64" => Token::SimpleType(DataType::Date64), + + "Second" => Token::TimeUnit(TimeUnit::Second), + "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), + "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond), + "Nanosecond" => Token::TimeUnit(TimeUnit::Nanosecond), + + "Timestamp" => Token::Timestamp, + "Time32" => Token::Time32, + "Time64" => Token::Time64, + "Duration" => Token::Duration, + "Interval" => Token::Interval, + "Dictionary" => Token::Dictionary, + + "FixedSizeBinary" => Token::FixedSizeBinary, + "Decimal128" => Token::Decimal128, + "Decimal256" => Token::Decimal256, + + "YearMonth" => Token::IntervalUnit(IntervalUnit::YearMonth), + "DayTime" => Token::IntervalUnit(IntervalUnit::DayTime), + "MonthDayNano" => Token::IntervalUnit(IntervalUnit::MonthDayNano), + + "None" => Token::None, + + _ => { + return Err(make_error( + self.val, + &format!("unrecognized word: {}", self.word), + )) + } + }; + Ok(token) + } +} + +impl<'a> Iterator for Tokenizer<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + 
loop { + match self.peek_next_char()? { + ' ' => { + // skip whitespace + self.next_char(); + continue; + } + '(' => { + self.next_char(); + return Some(Ok(Token::LParen)); + } + ')' => { + self.next_char(); + return Some(Ok(Token::RParen)); + } + ',' => { + self.next_char(); + return Some(Ok(Token::Comma)); + } + _ => return Some(self.parse_word()), + } + } + } +} + +/// The tokens produced by `Tokenizer` +/// and consumed by `Parser` +#[derive(Debug, PartialEq)] +enum Token { + // a type name that takes no parameters, e.g. Null or Int32 + SimpleType(DataType), + Timestamp, + Time32, + Time64, + Duration, + Interval, + FixedSizeBinary, + Decimal128, + Decimal256, + Dictionary, + TimeUnit(TimeUnit), + IntervalUnit(IntervalUnit), + LParen, + RParen, + Comma, + None, + Integer(i64), +} + +impl Display for Token { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Token::SimpleType(t) => write!(f, "{t}"), + Token::Timestamp => write!(f, "Timestamp"), + Token::Time32 => write!(f, "Time32"), + Token::Time64 => write!(f, "Time64"), + Token::Duration => write!(f, "Duration"), + Token::Interval => write!(f, "Interval"), + Token::TimeUnit(u) => write!(f, "TimeUnit({u:?})"), + Token::IntervalUnit(u) => write!(f, "IntervalUnit({u:?})"), + Token::LParen => write!(f, "("), + Token::RParen => write!(f, ")"), + Token::Comma => write!(f, ","), + Token::None => write!(f, "None"), + Token::FixedSizeBinary => write!(f, "FixedSizeBinary"), + Token::Decimal128 => write!(f, "Decimal128"), + Token::Decimal256 => write!(f, "Decimal256"), + Token::Dictionary => write!(f, "Dictionary"), + Token::Integer(v) => write!(f, "Integer({v})"), + } + } +} + +#[cfg(test)] +mod test { + use arrow_schema::{IntervalUnit, TimeUnit}; + + use super::*; + + #[test] + fn test_parse_data_type() { + // this ensures types can be parsed correctly from their string representations + for dt in list_datatypes() { + round_trip(dt) + } + } + + /// convert data_type to a string, and then parse it as a type + /// verifying it is the same + fn round_trip(data_type: 
DataType) { + let data_type_string = data_type.to_string(); + println!("Input '{data_type_string}' ({data_type:?})"); + let parsed_type = parse_data_type(&data_type_string).unwrap(); + assert_eq!( + data_type, parsed_type, + "Mismatch parsing {data_type_string}" + ); + } + + fn list_datatypes() -> Vec { + vec![ + // --------- + // Non Nested types + // --------- + DataType::Null, + DataType::Boolean, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + // TODO support timezones + //DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + DataType::Date32, + DataType::Date64, + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Microsecond), + DataType::Time32(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Second), + DataType::Time64(TimeUnit::Millisecond), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Duration(TimeUnit::Second), + DataType::Duration(TimeUnit::Millisecond), + DataType::Duration(TimeUnit::Microsecond), + DataType::Duration(TimeUnit::Nanosecond), + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Binary, + DataType::FixedSizeBinary(0), + DataType::FixedSizeBinary(1234), + DataType::FixedSizeBinary(-432), + DataType::LargeBinary, + DataType::Utf8, + DataType::LargeUtf8, + DataType::Decimal128(7, 12), + DataType::Decimal256(6, 13), + // --------- + // Nested types + // --------- + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + 
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Timestamp(TimeUnit::Nanosecond, None)), + ), + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::FixedSizeBinary(23)), + ), + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new( + // nested dictionaries are probably a bad idea but they are possible + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + ), + ), + // TODO support more structured types (List, LargeList, Struct, Union, Map, RunEndEncoded, etc) + ] + } + + #[test] + fn test_parse_data_type_whitespace_tolerance() { + // (string to parse, expected DataType) + let cases = [ + ("Int8", DataType::Int8), + ( + "Timestamp (Nanosecond, None)", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ( + "Timestamp (Nanosecond, None) ", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ( + " Timestamp (Nanosecond, None )", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ( + "Timestamp (Nanosecond, None ) ", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ]; + + for (data_type_string, expected_data_type) in cases { + println!("Parsing '{data_type_string}', expecting '{expected_data_type:?}'"); + let parsed_data_type = parse_data_type(data_type_string).unwrap(); + assert_eq!(parsed_data_type, expected_data_type); + } + } + + #[test] + fn parse_data_type_errors() { + // (string to parse, expected error message) + let cases = [ + ("", "Unsupported type ''"), + ("", "Error finding next token"), + ("null", "Unsupported type 'null'"), + ("Nu", "Unsupported type 'Nu'"), + // TODO support timezones + ( + r#"Timestamp(Nanosecond, Some("UTC"))"#, + "Error unrecognized word: Some", + ), + ("Timestamp(Nanosecond, ", "Error finding next token"), + ( + "Float32 Float32", + "trailing content after parsing 'Float32'", + ), + ("Int32, ", "trailing content after parsing 'Int32'"), + ("Int32(3), 
", "trailing content after parsing 'Int32'"), + ("FixedSizeBinary(Int32), ", "Error finding i64 for FixedSizeBinary, got 'Int32'"), + ("FixedSizeBinary(3.0), ", "Error parsing 3.0 as integer: invalid digit found in string"), + // too large for i32 + ("FixedSizeBinary(4000000000), ", "Error converting 4000000000 into i32 for FixedSizeBinary: out of range integral type conversion attempted"), + // can't have negative precision + ("Decimal128(-3, 5)", "Error converting -3 into u8 for Decimal128: out of range integral type conversion attempted"), + ("Decimal256(-3, 5)", "Error converting -3 into u8 for Decimal256: out of range integral type conversion attempted"), + ("Decimal128(3, 500)", "Error converting 500 into i8 for Decimal128: out of range integral type conversion attempted"), + ("Decimal256(3, 500)", "Error converting 500 into i8 for Decimal256: out of range integral type conversion attempted"), + + ]; + + for (data_type_string, expected_message) in cases { + print!("Parsing '{data_type_string}', expecting '{expected_message}'"); + match parse_data_type(data_type_string) { + Ok(d) => panic!( + "Expected error while parsing '{data_type_string}', but got '{d}'" + ), + Err(e) => { + let message = e.to_string(); + assert!( + message.contains(expected_message), + "\n\ndid not find expected in actual.\n\nexpected: {expected_message}\nactual:{message}\n" + ); + // errors should also contain a help message + assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); + } + } + println!(" Ok"); + } + } +} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c5f23213aa31c..68a5df054b69a 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -29,6 +29,8 @@ use sqlparser::ast::{ }; use std::str::FromStr; +use super::arrow_cast::ARROW_CAST_NAME; + impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, @@ 
-110,24 +112,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; // finally, user-defined functions (UDF) and UDAF - match self.schema_provider.get_function_meta(&name) { - Some(fm) => { - let args = self.function_args_to_expr(function.args, schema)?; + if let Some(fm) = self.schema_provider.get_function_meta(&name) { + let args = self.function_args_to_expr(function.args, schema)?; + return Ok(Expr::ScalarUDF { fun: fm, args }); + } - Ok(Expr::ScalarUDF { fun: fm, args }) - } - None => match self.schema_provider.get_aggregate_meta(&name) { - Some(fm) => { - let args = self.function_args_to_expr(function.args, schema)?; - Ok(Expr::AggregateUDF { - fun: fm, - args, - filter: None, - }) - } - _ => Err(DataFusionError::Plan(format!("Invalid function '{name}'"))), - }, + // User defined aggregate functions + if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { + let args = self.function_args_to_expr(function.args, schema)?; + return Ok(Expr::AggregateUDF { + fun: fm, + args, + filter: None, + }); } + + // Special case arrow_cast (as its type is dependent on its argument value) + if name == ARROW_CAST_NAME { + let args = self.function_args_to_expr(function.args, schema)?; + return super::arrow_cast::create_arrow_cast(args, schema); + } + + // Could not find the relevant function, so return an error + Err(DataFusionError::Plan(format!("Invalid function '{name}'"))) } pub(super) fn sql_named_function_to_expr( diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index f226924516ef8..ad05fbcc16ca8 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+pub(crate) mod arrow_cast; mod binary_op; mod function; mod grouping_set; diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index efe239f458111..c0c1a4ac91186 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -30,4 +30,5 @@ pub mod utils; mod values; pub use datafusion_common::{ResolvedTableReference, TableReference}; +pub use expr::arrow_cast::parse_data_type; pub use sqlparser; diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 71f5bf05e99e6..66095990701c7 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -2311,6 +2311,15 @@ fn approx_median_window() { quick_test(sql, expected); } +#[test] +fn select_arrow_cast() { + let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 'LargeUtf8')"; + let expected = "\ + Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\ + \n EmptyRelation"; + quick_test(sql, expected); +} + #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; @@ -2534,7 +2543,7 @@ impl ContextProvider for MockContextProvider { } fn get_function_meta(&self, _name: &str) -> Option> { - unimplemented!() + None } fn get_aggregate_meta(&self, name: &str) -> Option> { diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 968dcda53281d..9f0ca8f89467b 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -37,6 +37,18 @@ the `arrow_typeof` function. 
For example: +-------------------------------------+ ``` +You can cast a SQL expression to a specific Arrow type using the `arrow_cast` function. +For example, to cast the output of `now()` to a `Timestamp` with second precision: + +```sql +❯ select arrow_cast(now(), 'Timestamp(Second, None)'); ++---------------------+ +| now() | ++---------------------+ +| 2023-03-03T17:19:21 | ++---------------------+ +``` + ## Character Types | SQL DataType | Arrow DataType | @@ -65,12 +77,12 @@ the `arrow_typeof` function. For example: ## Date/Time Types -| SQL DataType | Arrow DataType | -| ------------ | :----------------------------------------------------------------------- | -| `DATE` | `Date32` | -| `TIME` | `Time64(TimeUnit::Nanosecond)` | -| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond, None)` | -| `INTERVAL` | `Interval(IntervalUnit::YearMonth)` or `Interval(IntervalUnit::DayTime)` | +| SQL DataType | Arrow DataType | +| ------------ | :---------------------------------------------- | +| `DATE` | `Date32` | +| `TIME` | `Time64(Nanosecond)` | +| `TIMESTAMP` | `Timestamp(Nanosecond, None)` | +| `INTERVAL` | `Interval(YearMonth)` or `Interval(DayTime)` | ## Boolean Types @@ -84,7 +96,7 @@ the `arrow_typeof` function. For example: | ------------ | :------------- | | `BYTEA` | `Binary` | -## Unsupported Types +## Unsupported SQL Types | SQL Data Type | Arrow DataType | | ------------- | :------------------ | @@ -100,3 +112,43 @@ the `arrow_typeof` function.
For example: | `ENUM` | _Not yet supported_ | | `SET` | _Not yet supported_ | | `DATETIME` | _Not yet supported_ | + +## Supported Arrow Types + +The following types are supported by the `arrow_cast` function: + +| Arrow Type | +| ----------------------------------------------------------- | +| `Null` | +| `Boolean` | +| `Int8` | +| `Int16` | +| `Int32` | +| `Int64` | +| `UInt8` | +| `UInt16` | +| `UInt32` | +| `UInt64` | +| `Float16` | +| `Float32` | +| `Float64` | +| `Utf8` | +| `LargeUtf8` | +| `Binary` | +| `Timestamp(Second, None)` | +| `Timestamp(Millisecond, None)` | +| `Timestamp(Microsecond, None)` | +| `Timestamp(Nanosecond, None)` | +| `Time32` | +| `Time64` | +| `Duration(Second)` | +| `Duration(Millisecond)` | +| `Duration(Microsecond)` | +| `Duration(Nanosecond)` | +| `Interval(YearMonth)` | +| `Interval(DayTime)` | +| `Interval(MonthDayNano)` | +| `Interval(MonthDayNano)` | +| `FixedSizeBinary(<len>)` (e.g. `FixedSizeBinary(16)`) | +| `Decimal128(<precision>, <scale>)` e.g. `Decimal128(3, 10)` | +| `Decimal256(<precision>, <scale>)` e.g. `Decimal256(3, 10)` |