From 52e5c5826b013e1d76921c93e6bd199162ef937d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 2 Feb 2023 16:23:20 -0500 Subject: [PATCH 01/11] Add `arrow_cast` function --- .../sqllogictests/test_files/arrow_typeof.slt | 228 +++++- datafusion/proto/src/logical_plan/mod.rs | 5 +- datafusion/sql/src/expr/arrow_cast.rs | 673 ++++++++++++++++++ datafusion/sql/src/expr/function.rs | 39 +- datafusion/sql/src/expr/mod.rs | 1 + datafusion/sql/src/lib.rs | 1 + datafusion/sql/tests/integration_test.rs | 11 +- docs/source/user-guide/sql/data_types.md | 60 +- 8 files changed, 968 insertions(+), 50 deletions(-) create mode 100644 datafusion/sql/src/expr/arrow_cast.rs diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 8f1c006510f68..72031471bac64 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -52,31 +52,203 @@ SELECT arrow_typeof(1.0::float) Float32 # arrow_typeof_decimal -# query T -# SELECT arrow_typeof(1::Decimal) -# ---- -# Decimal128(38, 10) - -# # arrow_typeof_timestamp -# query T -# SELECT arrow_typeof(now()::timestamp) -# ---- -# Timestamp(Nanosecond, None) - -# # arrow_typeof_timestamp_utc -# query T -# SELECT arrow_typeof(now()) -# ---- -# Timestamp(Nanosecond, Some(\"+00:00\")) - -# # arrow_typeof_timestamp_date32( -# query T -# SELECT arrow_typeof(now()::date) -# ---- -# Date32 - -# # arrow_typeof_utf8 -# query T -# SELECT arrow_typeof('1') -# ---- -# Utf8 +query T +SELECT arrow_typeof(1::Decimal) +---- +Decimal128(38, 10) + +# arrow_typeof_timestamp +query T +SELECT arrow_typeof(now()::timestamp) +---- +Timestamp(Nanosecond, None) + +# arrow_typeof_timestamp_utc +query T +SELECT arrow_typeof(now()) +---- +Timestamp(Nanosecond, Some("+00:00")) + +# arrow_typeof_timestamp_date32( +query T +SELECT arrow_typeof(now()::date) +---- +Date32 + +# arrow_typeof_utf8 +query 
T +SELECT arrow_typeof('1') +---- +Utf8 + + +#### arrow_cast (in some ways opposite of arrow_typeof) + + +query I +SELECT arrow_cast('1', 'Int16') +---- +1 + +query error Error during planning: arrow_cast needs 2 arguments, 1 provided +SELECT arrow_cast('1') + +query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\) +SELECT arrow_cast('1', 43) + +query error Error unrecognized word: unknown +SELECT arrow_cast('1', 'unknown') + + +## Basic types + +statement ok +create table foo as select + arrow_cast(1, 'Int8') as col_i8, + arrow_cast(1, 'Int16') as col_i16, + arrow_cast(1, 'Int32') as col_i32, + arrow_cast(1, 'Int64') as col_i64, + arrow_cast(1, 'UInt8') as col_u8, + arrow_cast(1, 'UInt16') as col_u16, + arrow_cast(1, 'UInt32') as col_u32, + arrow_cast(1, 'UInt64') as col_u64, + -- can't seem to cast to Float16 for some reason + arrow_cast(1.0, 'Float32') as col_f32, + arrow_cast(1.0, 'Float64') as col_f64 +; + + +query TTTTTTTTTT +SELECT + arrow_typeof(col_i8), + arrow_typeof(col_i16), + arrow_typeof(col_i32), + arrow_typeof(col_i64), + arrow_typeof(col_u8), + arrow_typeof(col_u16), + arrow_typeof(col_u32), + arrow_typeof(col_u64), + arrow_typeof(col_f32), + arrow_typeof(col_f64) + FROM foo; +---- +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 + + + +statement ok +drop table foo + +## Decimals + +statement ok +create table foo as select + arrow_cast(100, 'Decimal128(3,2)') as col_d128 + -- Can't make a decimal 256: + -- This feature is not implemented: Can't create a scalar from array of type "Decimal256(3, 2)" + --arrow_cast(100, 'Decimal256(3,2)') as col_d256 +; + + +query T +SELECT + arrow_typeof(col_d128) + -- arrow_typeof(col_d256), + FROM foo; +---- +Decimal128(3, 2) + + +statement ok +drop table foo + +## strings, large strings + +statement ok +create table foo as select + arrow_cast('foo', 'Utf8') as col_utf8, + arrow_cast('foo', 'LargeUtf8') as col_large_utf8, +
arrow_cast('foo', 'Binary') as col_binary, + arrow_cast('foo', 'LargeBinary') as col_large_binary +; + + +query TTTT +SELECT + arrow_typeof(col_utf8), + arrow_typeof(col_large_utf8), + arrow_typeof(col_binary), + arrow_typeof(col_large_binary) + FROM foo; +---- +Utf8 LargeUtf8 Binary LargeBinary + + +statement ok +drop table foo + + +## timestamps + +statement ok +create table foo as select + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') as col_ts_s, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)') as col_ts_ms, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)') as col_ts_us, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns +; + + +query TTTT +SELECT + arrow_typeof(col_ts_s), + arrow_typeof(col_ts_ms), + arrow_typeof(col_ts_us), + arrow_typeof(col_ts_ns) + FROM foo; +---- +Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) + + +statement ok +drop table foo + + +## durations + +statement ok +create table foo as select + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') as col_ts_s, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)') as col_ts_ms, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)') as col_ts_us, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns +; + + +query TTTT +SELECT + arrow_typeof(col_ts_s), + arrow_typeof(col_ts_ms), + arrow_typeof(col_ts_us), + arrow_typeof(col_ts_ns) + FROM foo; +---- +Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) + + +statement ok +drop table foo + + +## intervals + +query error Cannot automatically convert Interval\(DayTime\) to 
Interval\(MonthDayNano\) +--- +select arrow_cast(interval '30 minutes', 'Interval(MonthDayNano)'); + + +## duration + +query error Cannot automatically convert Interval\(DayTime\) to Duration\(Second\) +--- +select arrow_cast(interval '30 minutes', 'Duration(Second)'); diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 2f4f83d89d7e4..ca8e3dd091c7c 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -2113,8 +2113,11 @@ mod roundtrip_tests { DataType::Float16, DataType::Float32, DataType::Float64, - // Add more timestamp tests + DataType::Timestamp(TimeUnit::Second, None), DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), DataType::Date32, DataType::Date64, DataType::Time32(TimeUnit::Second), diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs new file mode 100644 index 0000000000000..806ee251826c0 --- /dev/null +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -0,0 +1,673 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation of the `arrow_cast` function that allows +//! casting to arbitrary arrow types (rather than SQL types) + +use std::{fmt::Display, iter::Peekable, str::Chars}; + +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; +use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; + +use datafusion_expr::{Expr, ExprSchemable}; + +pub const ARROW_CAST_NAME: &str = "arrow_cast"; + +/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// +/// This function is not a [`BuiltInScalarFunction`] because the +/// return type of [`BuiltInScalarFunction`] depends only on the +/// *types* of the arguments. However, the type of `arrow_cast` depends on +/// the *value* of its second argument. +/// +/// Use the `cast` function to cast to SQL type (which is then mapped +/// to the corresponding arrow type). For example to cast to `int` +/// (which is then mapped to the arrow type `Int32`) +/// +/// ```sql +/// select cast(column_x as int) ... +/// ``` +/// +/// Use the `arrow_cast` function to cast to a specific arrow type +/// +/// For example +/// ```sql +/// select arrow_cast(column_x, 'Float64') +/// ``` +pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "arrow_cast needs 2 arguments, {} provided", + args.len() + ))); + } + let arg1 = args.pop().unwrap(); + let arg0 = args.pop().unwrap(); + + // arg1 must be a string + let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { + v + } else { + return Err(DataFusionError::Plan(format!( + "arrow_cast requires its second argument to be a constant string, got {arg1}" + ))); + }; + + // do the actual lookup to the appropriate data type + let data_type = parse_data_type(&data_type_string)?; + + arg0.cast_to(&data_type, schema) +} + +/// Parses `str` into a `DataType`.
+/// +/// `parse_data_type` is the reverse of [`DataType`]'s `Display` +/// impl, and maintains the invariant that +/// `parse_data_type(data_type.to_string()) == data_type` +/// +/// Example: +/// ``` +/// # use datafusion_sql::parse_data_type; +/// # use arrow_schema::DataType; +/// let display_value = "Int32"; +/// +/// // "Int32" is the Display value of `DataType` +/// assert_eq!(display_value, &format!("{}", DataType::Int32)); +/// +/// // parse_data_type converts "Int32" back to `DataType`: +/// let data_type = parse_data_type(display_value).unwrap(); +/// assert_eq!(data_type, DataType::Int32); +/// ``` +/// +/// TODO file a ticket about bringing this into arrow possibly +pub fn parse_data_type(val: &str) -> Result { + Parser::new(val).parse() +} + +fn make_error(val: &str, msg: &str) -> DataFusionError { + DataFusionError::Plan( + format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanoseconds, None)'. Error {msg}" ) + ) +} + +fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError { + make_error(val, &format!("Expected '{expected}', got '{actual}'")) +} + +#[derive(Debug)] +/// Implementation of `parse_data_type`, modeled after +struct Parser<'a> { + val: &'a str, + tokenizer: Tokenizer<'a>, +} + +impl<'a> Parser<'a> { + fn new(val: &'a str) -> Self { + Self { + val, + tokenizer: Tokenizer::new(val), + } + } + + fn parse(mut self) -> Result { + let data_type = self.parse_next_type()?; + // ensure that there is no trailing content + if self.tokenizer.peek_next_char().is_some() { + return Err(make_error( + self.val, + &format!("checking trailing content after parsing '{data_type}'"), + )); + } else { + Ok(data_type) + } + } + + /// parses the next full DataType + fn parse_next_type(&mut self) -> Result { + match self.next_token()?
{ + Token::SimpleType(data_type) => Ok(data_type), + Token::Timestamp => self.parse_timestamp(), + Token::Time32 => self.parse_time32(), + Token::Time64 => self.parse_time64(), + Token::Duration => self.parse_duration(), + Token::Interval => self.parse_interval(), + Token::FixedSizeBinary => self.parse_fixed_size_binary(), + Token::Decimal128 => self.parse_decimal_128(), + Token::Decimal256 => self.parse_decimal_256(), + Token::Dictionary => self.parse_dictionary(), + tok => Err(make_error( + self.val, + &format!("finding next type, got unexpected '{tok}'"), + )), + } + } + + /// Parses the next timeunit + fn parse_time_unit(&mut self, context: &str) -> Result { + match self.next_token()? { + Token::TimeUnit(time_unit) => Ok(time_unit), + tok => Err(make_error( + self.val, + &format!("finding TimeUnit for {context}, got {tok}"), + )), + } + } + + /// Parses the next integer value + fn parse_i64(&mut self, context: &str) -> Result { + match self.next_token()? { + Token::Integer(v) => Ok(v), + tok => Err(make_error( + self.val, + &format!("finding i64 for {context}, got '{tok}'"), + )), + } + } + + /// Parses the next i32 integer value + fn parse_i32(&mut self, context: &str) -> Result { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into i32 for {context}: {e}"), + ) + }) + } + + /// Parses the next i8 integer value + fn parse_i8(&mut self, context: &str) -> Result { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into i8 for {context}: {e}"), + ) + }) + } + + /// Parses the next u8 integer value + fn parse_u8(&mut self, context: &str) -> Result { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into u8 for {context}: {e}"), + ) + }) + } + + /// Parses the next timestamp (called after `Timestamp` has been consumed) 
+ fn parse_timestamp(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Timestamp")?; + self.expect_token(Token::Comma)?; + // TODO Support timezones other than None + self.expect_token(Token::None)?; + let timezone = None; + + self.expect_token(Token::RParen)?; + Ok(DataType::Timestamp(time_unit, timezone)) + } + + /// Parses the next Time32 (called after `Time32` has been consumed) + fn parse_time32(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Time32")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Time32(time_unit)) + } + + /// Parses the next Time64 (called after `Time64` has been consumed) + fn parse_time64(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Time64")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Time64(time_unit)) + } + + /// Parses the next Duration (called after `Duration` has been consumed) + fn parse_duration(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Duration")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Duration(time_unit)) + } + + /// Parses the next Interval (called after `Interval` has been consumed) + fn parse_interval(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let interval_unit = match self.next_token()?
{ + Token::IntervalUnit(interval_unit) => interval_unit, + tok => { + return Err(make_error( + self.val, + &format!("finding IntervalUnit for Interval, got {tok}"), + )) + } + }; + self.expect_token(Token::RParen)?; + Ok(DataType::Interval(interval_unit)) + } + + /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has been consumed) + fn parse_fixed_size_binary(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let length = self.parse_i32("FixedSizeBinary")?; + self.expect_token(Token::RParen)?; + Ok(DataType::FixedSizeBinary(length)) + } + + /// Parses the next Decimal128 (called after `Decimal128` has been consumed) + fn parse_decimal_128(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let precision = self.parse_u8("Decimal128")?; + self.expect_token(Token::Comma)?; + let scale = self.parse_i8("Decimal128")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Decimal128(precision, scale)) + } + + /// Parses the next Decimal256 (called after `Decimal256` has been consumed) + fn parse_decimal_256(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let precision = self.parse_u8("Decimal256")?; + self.expect_token(Token::Comma)?; + let scale = self.parse_i8("Decimal256")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Decimal256(precision, scale)) + } + + /// Parses the next Dictionary (called after `Dictionary` has been consumed) + fn parse_dictionary(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let key_type = self.parse_next_type()?; + self.expect_token(Token::Comma)?; + let value_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::Dictionary( + Box::new(key_type), + Box::new(value_type), + )) + } + + /// return the next token, or an error if there are none left + fn next_token(&mut self) -> Result { + match self.tokenizer.next() { + None => Err(make_error(self.val, "finding next token")), + Some(token) => token, + } + } + + /// consume the next token, 
returning Ok(()) if it matches tok, and Err if not + fn expect_token(&mut self, tok: Token) -> Result<()> { + let next_token = self.next_token()?; + if next_token == tok { + Ok(()) + } else { + Err(make_error_expected(self.val, &tok, &next_token)) + } + } +} + +/// returns true if this character is a separator +fn is_separator(c: char) -> bool { + c == '(' || c == ')' || c == ',' || c == ' ' +} + +#[derive(Debug)] +/// Splits a string like Dictionary(Int32, Int64) into tokens suitable for parsing +/// +/// For example the string "Timestamp(Nanosecond, None)" would be parsed into: +/// +/// * Token::Timestamp +/// * Token::LParen +/// * Token::TimeUnit(TimeUnit::Nanosecond) +/// * Token::Comma, +/// * Token::None, +/// * Token::RParen, +struct Tokenizer<'a> { + val: &'a str, + chars: Peekable>, +} + +impl<'a> Tokenizer<'a> { + fn new(val: &'a str) -> Self { + Self { + val, + chars: val.chars().peekable(), + } + } + + /// returns the next char, without consuming it + fn peek_next_char(&mut self) -> Option { + self.chars.peek().copied() + } + + /// returns the next char, consuming it + fn next_char(&mut self) -> Option { + self.chars.next() + } + + /// parse the characters in val starting at the current position, until the next + /// separator (`,`, `(`, `)`, or a space) or the end of the input + fn parse_word(&mut self) -> Result { + let mut word = String::new(); + loop { + match self.peek_next_char() { + None => break, + Some(c) if is_separator(c) => break, + Some(c) => { + self.next_char(); + word.push(c); + } + } + } + + // if it started with a number, try parsing it as an integer + if let Some(c) = word.chars().next() { + if c == '-' || c.is_numeric() { + let val: i64 = word.parse().map_err(|e| { + make_error(self.val, &format!("parsing {word} as integer: {e}")) + })?; + return Ok(Token::Integer(val)); + } + } + + // figure out what the word was + let token = match word.as_str() { + "Null" => Token::SimpleType(DataType::Null), + "Boolean" => Token::SimpleType(DataType::Boolean), + + "Int8" =>
Token::SimpleType(DataType::Int8), + "Int16" => Token::SimpleType(DataType::Int16), + "Int32" => Token::SimpleType(DataType::Int32), + "Int64" => Token::SimpleType(DataType::Int64), + + "UInt8" => Token::SimpleType(DataType::UInt8), + "UInt16" => Token::SimpleType(DataType::UInt16), + "UInt32" => Token::SimpleType(DataType::UInt32), + "UInt64" => Token::SimpleType(DataType::UInt64), + + "Utf8" => Token::SimpleType(DataType::Utf8), + "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8), + "Binary" => Token::SimpleType(DataType::Binary), + "LargeBinary" => Token::SimpleType(DataType::LargeBinary), + + "Float16" => Token::SimpleType(DataType::Float16), + "Float32" => Token::SimpleType(DataType::Float32), + "Float64" => Token::SimpleType(DataType::Float64), + + "Date32" => Token::SimpleType(DataType::Date32), + "Date64" => Token::SimpleType(DataType::Date64), + + "Second" => Token::TimeUnit(TimeUnit::Second), + "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), + "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond), + "Nanosecond" => Token::TimeUnit(TimeUnit::Nanosecond), + + "Timestamp" => Token::Timestamp, + "Time32" => Token::Time32, + "Time64" => Token::Time64, + "Duration" => Token::Duration, + "Interval" => Token::Interval, + "Dictionary" => Token::Dictionary, + + "FixedSizeBinary" => Token::FixedSizeBinary, + "Decimal128" => Token::Decimal128, + "Decimal256" => Token::Decimal256, + + "YearMonth" => Token::IntervalUnit(IntervalUnit::YearMonth), + "DayTime" => Token::IntervalUnit(IntervalUnit::DayTime), + "MonthDayNano" => Token::IntervalUnit(IntervalUnit::MonthDayNano), + + "None" => Token::None, + + _ => return Err(make_error(self.val, &format!("unrecognized word: {word}"))), + }; + Ok(token) + } +} + +impl<'a> Iterator for Tokenizer<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + loop { + match self.peek_next_char()? 
{ + ' ' => { + // skip whitespace + self.next_char(); + continue; + } + '(' => { + self.next_char(); + return Some(Ok(Token::LParen)); + } + ')' => { + self.next_char(); + return Some(Ok(Token::RParen)); + } + ',' => { + self.next_char(); + return Some(Ok(Token::Comma)); + } + _ => return Some(self.parse_word()), + } + } + } +} + +/// Grammar is +/// +#[derive(Debug, PartialEq)] +enum Token { + // Null, or Int32 + SimpleType(DataType), + Timestamp, + Time32, + Time64, + Duration, + Interval, + FixedSizeBinary, + Decimal128, + Decimal256, + Dictionary, + TimeUnit(TimeUnit), + IntervalUnit(IntervalUnit), + LParen, + RParen, + Comma, + None, + Integer(i64), +} + +impl Display for Token { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Token::SimpleType(t) => write!(f, "{t}"), + Token::Timestamp => write!(f, "Timestamp"), + Token::Time32 => write!(f, "Time32"), + Token::Time64 => write!(f, "Time64"), + Token::Duration => write!(f, "Duration"), + Token::Interval => write!(f, "Interval"), + Token::TimeUnit(u) => write!(f, "TimeUnit({u:?})"), + Token::IntervalUnit(u) => write!(f, "IntervalUnit({u:?})"), + Token::LParen => write!(f, "("), + Token::RParen => write!(f, ")"), + Token::Comma => write!(f, ","), + Token::None => write!(f, "None"), + Token::FixedSizeBinary => write!(f, "FixedSizeBinary"), + Token::Decimal128 => write!(f, "Decimal128"), + Token::Decimal256 => write!(f, "Decimal256"), + Token::Dictionary => write!(f, "Dictionary"), + Token::Integer(v) => write!(f, "Integer({v})"), + } + } +} + +#[cfg(test)] +mod test { + use arrow_schema::{IntervalUnit, TimeUnit}; + + use super::*; + + #[test] + fn test_parse_data_type() { + // this ensures types can be parsed correctly from their string representations + for dt in list_datatypes() { + round_trip(dt) + } + } + + /// convert data_type to a string, and then parse it as a type + /// verifying it is the same + fn round_trip(data_type: DataType) { + let data_type_string = 
data_type.to_string(); + println!("Input '{data_type_string}' ({data_type:?})"); + let parsed_type = parse_data_type(&data_type_string).unwrap(); + assert_eq!( + data_type, parsed_type, + "Mismatch parsing {data_type_string}" + ); + } + + fn list_datatypes() -> Vec { + vec![ + // --------- + // Non Nested types + // --------- + DataType::Null, + DataType::Boolean, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + // TODO support timezones + //DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + DataType::Date32, + DataType::Date64, + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Microsecond), + DataType::Time32(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Second), + DataType::Time64(TimeUnit::Millisecond), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Duration(TimeUnit::Second), + DataType::Duration(TimeUnit::Millisecond), + DataType::Duration(TimeUnit::Microsecond), + DataType::Duration(TimeUnit::Nanosecond), + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Binary, + DataType::FixedSizeBinary(0), + DataType::FixedSizeBinary(1234), + DataType::FixedSizeBinary(-432), + DataType::LargeBinary, + DataType::Utf8, + DataType::LargeUtf8, + DataType::Decimal128(7, 12), + DataType::Decimal256(6, 13), + // --------- + // Nested types + // --------- + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary(Box::new(DataType::Int8), 
Box::new(DataType::Utf8)), + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::FixedSizeBinary(23)), + ), + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new( + // nested dictionaries are probably a bad idea but they are possible + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + ), + ), + // TODO support more structured types (List, LargeList, Struct, Union, Map, RunEndEncoded, etc) + ] + } + + #[test] + fn parse_data_type_errors() { + // (string to parse, expected error message) + let cases = [ + ("", "Unsupported type ''"), + ("", "Error finding next token"), + ("null", "Unsupported type 'null'"), + ("Nu", "Unsupported type 'Nu'"), + // TODO support timezones + ( + r#"Timestamp(Nanosecond, Some("UTC"))"#, + "Error unrecognized word: Some", + ), + ("Timestamp(Nanosecond, ", "Error finding next token"), + ( + "Float32 Float32", + "trailing content after parsing 'Float32'", + ), + ("Int32, ", "trailing content after parsing 'Int32'"), + ("Int32(3), ", "trailing content after parsing 'Int32'"), + ("FixedSizeBinary(Int32), ", "Error finding i64 for FixedSizeBinary, got 'Int32'"), + ("FixedSizeBinary(3.0), ", "Error parsing 3.0 as integer: invalid digit found in string"), + // too large for i32 + ("FixedSizeBinary(4000000000), ", "Error converting 4000000000 into i32 for FixedSizeBinary: out of range integral type conversion attempted"), + // can't have negative precision + ("Decimal128(-3, 5)", "Error converting -3 into u8 for Decimal128: out of range integral type conversion attempted"), + ("Decimal256(-3, 5)", "Error converting -3 into u8 for Decimal256: out of range integral type conversion attempted"), + ("Decimal128(3, 500)", "Error converting 500 into i8 for Decimal128: out of range integral type conversion attempted"), + ("Decimal256(3, 500)", "Error converting 500 into i8 for Decimal256: out of range integral type conversion attempted"), + + ]; + + for (data_type_string, expected_message) in 
cases { + print!("Parsing '{data_type_string}', expecting '{expected_message}'"); + match parse_data_type(data_type_string) { + Ok(d) => panic!( + "Expected error while parsing '{data_type_string}', but got '{d}'" + ), + Err(e) => { + let message = e.to_string(); + assert!( + message.contains(expected_message), + "\n\ndid not find expected in actual.\n\nexpected: {expected_message}\nactual:{message}\n" + ); + // errors should also contain a help message + assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanoseconds, None)'")); + } + } + println!(" Ok"); + } + } +} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c5f23213aa31c..68a5df054b69a 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -29,6 +29,8 @@ use sqlparser::ast::{ }; use std::str::FromStr; +use super::arrow_cast::ARROW_CAST_NAME; + impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, @@ -110,24 +112,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; // finally, user-defined functions (UDF) and UDAF - match self.schema_provider.get_function_meta(&name) { - Some(fm) => { - let args = self.function_args_to_expr(function.args, schema)?; + if let Some(fm) = self.schema_provider.get_function_meta(&name) { + let args = self.function_args_to_expr(function.args, schema)?; + return Ok(Expr::ScalarUDF { fun: fm, args }); + } - Ok(Expr::ScalarUDF { fun: fm, args }) - } - None => match self.schema_provider.get_aggregate_meta(&name) { - Some(fm) => { - let args = self.function_args_to_expr(function.args, schema)?; - Ok(Expr::AggregateUDF { - fun: fm, - args, - filter: None, - }) - } - _ => Err(DataFusionError::Plan(format!("Invalid function '{name}'"))), - }, + // User defined aggregate functions + if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { + let args = self.function_args_to_expr(function.args, schema)?; + return 
Ok(Expr::AggregateUDF { + fun: fm, + args, + filter: None, + }); } + + // Special case arrow_cast (as its type is dependent on its argument value) + if name == ARROW_CAST_NAME { + let args = self.function_args_to_expr(function.args, schema)?; + return super::arrow_cast::create_arrow_cast(args, schema); + } + + // Could not find the relevant function, so return an error + Err(DataFusionError::Plan(format!("Invalid function '{name}'"))) } pub(super) fn sql_named_function_to_expr( diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index f226924516ef8..ad05fbcc16ca8 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub(crate) mod arrow_cast; mod binary_op; mod function; mod grouping_set; diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index efe239f458111..c0c1a4ac91186 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -30,4 +30,5 @@ pub mod utils; mod values; pub use datafusion_common::{ResolvedTableReference, TableReference}; +pub use expr::arrow_cast::parse_data_type; pub use sqlparser; diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 44c0559ef35a3..a57e8d2a260ce 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -2311,6 +2311,15 @@ fn approx_median_window() { quick_test(sql, expected); } +#[test] +fn select_arrow_cast() { + let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 'LargeUtf8')"; + let expected = "\ + Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\ + \n EmptyRelation"; + quick_test(sql, expected); +} + #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; @@ -2534,7 +2543,7 @@ impl ContextProvider for MockContextProvider { } fn get_function_meta(&self, _name: &str) -> Option> { - 
unimplemented!() + None } fn get_aggregate_meta(&self, name: &str) -> Option> { diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 968dcda53281d..86976bb03ee24 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -37,6 +37,18 @@ the `arrow_typeof` function. For example: +-------------------------------------+ ``` +You can cast a SQL expression to a specific Arrow type using the `arrow_cast` function +For example, to cast the output of `now()` to a `Timestamp` with second precision rather: + +```sql +❯ select arrow_cast(now(), 'Timestamp(Second, None)'); ++---------------------+ +| now() | ++---------------------+ +| 2023-03-03T17:19:21 | ++---------------------+ +``` + ## Character Types | SQL DataType | Arrow DataType | @@ -68,9 +80,9 @@ the `arrow_typeof` function. For example: | SQL DataType | Arrow DataType | | ------------ | :----------------------------------------------------------------------- | | `DATE` | `Date32` | -| `TIME` | `Time64(TimeUnit::Nanosecond)` | -| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond, None)` | -| `INTERVAL` | `Interval(IntervalUnit::YearMonth)` or `Interval(IntervalUnit::DayTime)` | +| `TIME` | `Time64(Nanosecond)` | +| `TIMESTAMP` | `Timestamp(Nanosecond, None)` | +| `INTERVAL` | `Interval(IntervalUnit)` or `Interval(DayTime)` | ## Boolean Types @@ -84,7 +96,7 @@ the `arrow_typeof` function. For example: | ------------ | :------------- | | `BYTEA` | `Binary` | -## Unsupported Types +## Unsupported SQL Types | SQL Data Type | Arrow DataType | | ------------- | :------------------ | @@ -100,3 +112,43 @@ the `arrow_typeof` function. 
For example: | `ENUM` | _Not yet supported_ | | `SET` | _Not yet supported_ | | `DATETIME` | _Not yet supported_ | + +## Supported Arrow Types + +The following types are supported by the `arrow_typeof` function: + +| Arrow Type | +|--------------------------------| +| `Null` | +| `Boolean` | +| `Int8` | +| `Int16` | +| `Int32` | +| `Int64` | +| `UInt8` | +| `UInt16` | +| `UInt32` | +| `UInt64` | +| `Float16` | +| `Float32` | +| `Float64` | +| `Utf8` | +| `LargeUtf8` | +| `Binary` | +| `Timestamp(Second, None)` | +| `Timestamp(Millisecond, None)` | +| `Timestamp(Microsecond, None)` | +| `Timestamp(Nanosecond, None)` | +| `Time32` | +| `Time64` | +| `Duration(Second)` | +| `Duration(Millisecond)` | +| `Duration(Microsecond)` | +| `Duration(Nanosecond)` | +| `Interval(YearMonth)` | +| `Interval(DayTime)` | +| `Interval(MonthDayNano)` | +| `Interval(MonthDayNano)` | +| `FixedSizeBinary(<len>)` (e.g. `FixedSizeBinary(16)`) | +| `Decimal128(<precision>, <scale>)` e.g. `Decimal128(3, 10)` | +| `Decimal256(<precision>, <scale>)` e.g. `Decimal256(3, 10)` | From a58f00a2de1612db1d1a8ee25d93d8e1d4e60058 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 3 Mar 2023 13:03:33 -0500 Subject: [PATCH 02/11] prettier --- docs/source/user-guide/sql/data_types.md | 82 ++++++++++++------------ 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 86976bb03ee24..9f0ca8f89467b 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -77,12 +77,12 @@ For example, to cast the output of `now()` to a `Timestamp` with second precisio ## Date/Time Types -| SQL DataType | Arrow DataType | -| ------------ | :----------------------------------------------------------------------- | -| `DATE` | `Date32` | -| `TIME` | `Time64(Nanosecond)` | -| `TIMESTAMP` | `Timestamp(Nanosecond, None)` | -| `INTERVAL` | `Interval(IntervalUnit)` or `Interval(DayTime)` | +| SQL DataType | Arrow DataType | +|
------------ | :---------------------------------------------- | +| `DATE` | `Date32` | +| `TIME` | `Time64(Nanosecond)` | +| `TIMESTAMP` | `Timestamp(Nanosecond, None)` | +| `INTERVAL` | `Interval(IntervalUnit)` or `Interval(DayTime)` | ## Boolean Types @@ -117,38 +117,38 @@ For example, to cast the output of `now()` to a `Timestamp` with second precisio The following types are supported by the `arrow_typeof` function: -| Arrow Type | -|--------------------------------| -| `Null` | -| `Boolean` | -| `Int8` | -| `Int16` | -| `Int32` | -| `Int64` | -| `UInt8` | -| `UInt16` | -| `UInt32` | -| `UInt64` | -| `Float16` | -| `Float32` | -| `Float64` | -| `Utf8` | -| `LargeUtf8` | -| `Binary` | -| `Timestamp(Second, None)` | -| `Timestamp(Millisecond, None)` | -| `Timestamp(Microsecond, None)` | -| `Timestamp(Nanosecond, None)` | -| `Time32` | -| `Time64` | -| `Duration(Second)` | -| `Duration(Millisecond)` | -| `Duration(Microsecond)` | -| `Duration(Nanosecond)` | -| `Interval(YearMonth)` | -| `Interval(DayTime)` | -| `Interval(MonthDayNano)` | -| `Interval(MonthDayNano)` | -| `FixedSizeBinary(<len>)` (e.g. `FixedSizeBinary(16)`) | -| `Decimal128(<precision>, <scale>)` e.g. `Decimal128(3, 10)` | -| `Decimal256(<precision>, <scale>)` e.g. `Decimal256(3, 10)` | +| Arrow Type | +| ----------------------------------------------------------- | +| `Null` | +| `Boolean` | +| `Int8` | +| `Int16` | +| `Int32` | +| `Int64` | +| `UInt8` | +| `UInt16` | +| `UInt32` | +| `UInt64` | +| `Float16` | +| `Float32` | +| `Float64` | +| `Utf8` | +| `LargeUtf8` | +| `Binary` | +| `Timestamp(Second, None)` | +| `Timestamp(Millisecond, None)` | +| `Timestamp(Microsecond, None)` | +| `Timestamp(Nanosecond, None)` | +| `Time32` | +| `Time64` | +| `Duration(Second)` | +| `Duration(Millisecond)` | +| `Duration(Microsecond)` | +| `Duration(Nanosecond)` | +| `Interval(YearMonth)` | +| `Interval(DayTime)` | +| `Interval(MonthDayNano)` | +| `Interval(MonthDayNano)` | +| `FixedSizeBinary(<len>)` (e.g.
`FixedSizeBinary(16)`) | +| `Decimal128(<precision>, <scale>)` e.g. `Decimal128(3, 10)` | +| `Decimal256(<precision>, <scale>)` e.g. `Decimal256(3, 10)` | From 78965dbaa8bb296974fb1640c260117291aa37ad Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 5 Mar 2023 06:56:23 -0500 Subject: [PATCH 03/11] Update datafusion/sql/src/expr/arrow_cast.rs Co-authored-by: Wei-Ting Kuo --- datafusion/sql/src/expr/arrow_cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 806ee251826c0..8731123327776 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -216,7 +216,7 @@ impl<'a> Parser<'a> { self.expect_token(Token::LParen)?; let time_unit = self.parse_time_unit("Timestamp")?; self.expect_token(Token::Comma)?; - // TODO Support timezones other than Non + // TODO Support timezones other than None self.expect_token(Token::None)?; let timezone = None; From 3df081f188e86eaa1517a5986e07108ee26db1ca Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 5 Mar 2023 06:59:48 -0500 Subject: [PATCH 04/11] Apply suggestions from code review Co-authored-by: Wei-Ting Kuo --- datafusion/sql/src/expr/arrow_cast.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 8731123327776..a52bc3064cd90 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -100,7 +100,7 @@ pub fn parse_data_type(val: &str) -> Result { fn make_error(val: &str, msg: &str) -> DataFusionError { DataFusionError::Plan( - format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanoseconds, None)'. Error {msg}" ) + format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'.
Error {msg}" ) ) } @@ -336,7 +336,7 @@ fn is_separator(c: char) -> bool { /// /// * Token::Timestamp /// * Token::Lparen -/// * Token::IntervalUnit(IntervalUnit::Nanoseconds) +/// * Token::IntervalUnit(IntervalUnit::Nanosecond) /// * Token::Comma, /// * Token::None, /// * Token::Rparen, @@ -664,7 +664,7 @@ mod test { "\n\ndid not find expected in actual.\n\nexpected: {expected_message}\nactual:{message}\n" ); // errors should also contain a help message - assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanoseconds, None)'")); + assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); } } println!(" Ok"); From d3b17e8e5d8ebf538b53502066b4538211544f3f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Mar 2023 13:59:24 -0500 Subject: [PATCH 05/11] Clarify intent of tests --- .../sqllogictests/test_files/arrow_typeof.slt | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 72031471bac64..2091fb107b456 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -84,12 +84,14 @@ Utf8 #### arrow_cast (in some ways opposite of arrow_typeof) +# Basic tests query I SELECT arrow_cast('1', 'Int16') ---- 1 +# Basic error test query error Error during planning: arrow_cast needs 2 arguments, 1 provided SELECT arrow_cast('1') @@ -100,7 +102,7 @@ query error Error unrecognized word: unknown SELECT arrow_cast('1', 'unknown') -## Basic types +## Basic Types: Create a table statement ok create table foo as select @@ -117,6 +119,7 @@ create table foo as select arrow_cast(1.0, 'Float64') as col_f64 ; +## Ensure each column in the table has the expected type query TTTTTTTTTT SELECT @@ -135,11 +138,10 @@ SELECT Int8 Int16 Int32 
Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 - statement ok drop table foo -## Decimals +## Decimals: Create a table statement ok create table foo as select @@ -150,6 +152,8 @@ create table foo as select ; +## Ensure each column in the table has the expected type + query T SELECT arrow_typeof(col_d128) @@ -162,7 +166,7 @@ Decimal128(3, 2) statement ok drop table foo -## strings, large strings +## Strings, Binary: Create a table statement ok create table foo as select @@ -172,6 +176,7 @@ create table foo as select arrow_cast('foo', 'LargeBinary') as col_large_binary ; +## Ensure each column in the table has the expected type query TTTT SELECT @@ -188,7 +193,7 @@ statement ok drop table foo -## timestamps +## Timestamps: Create a table statement ok create table foo as select @@ -198,6 +203,7 @@ create table foo as select arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns ; +## Ensure each column in the table has the expected type query TTTT SELECT @@ -214,7 +220,7 @@ statement ok drop table foo -## durations +## Durations: Create a table statement ok create table foo as select @@ -224,6 +230,7 @@ create table foo as select arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns ; +## Ensure each column in the table has the expected type query TTTT SELECT @@ -240,14 +247,14 @@ statement ok drop table foo -## intervals +## Intervals: query error Cannot automatically convert Interval\(DayTime\) to Interval\(MonthDayNano\) --- select arrow_cast(interval '30 minutes', 'Interval(MonthDayNano)'); -## duration +## Duration query error Cannot automatically convert Interval\(DayTime\) to Duration\(Second\) --- From 7fa282bd5767f64ac5bcebfefa8eca35aea309f1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Mar 2023 14:01:50 -0500 Subject: [PATCH 06/11] Add more error tests --- .../core/tests/sqllogictests/test_files/arrow_typeof.slt | 6 ++++++ 1 file changed, 6 
insertions(+) diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 2091fb107b456..22cc7afe7231f 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -253,9 +253,15 @@ query error Cannot automatically convert Interval\(DayTime\) to Interval\(MonthD --- select arrow_cast(interval '30 minutes', 'Interval(MonthDayNano)'); +query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Interval(MonthDayNano) +select arrow_cast('30 minutes', 'Interval(MonthDayNano)'); + ## Duration query error Cannot automatically convert Interval\(DayTime\) to Duration\(Second\) --- select arrow_cast(interval '30 minutes', 'Duration(Second)'); + +query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration(Second) +select arrow_cast('30 minutes', 'Duration(Second)'); From a4f27534dec7088e656c14523957871a6ff26f81 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Mar 2023 14:12:51 -0500 Subject: [PATCH 07/11] More tests --- .../sqllogictests/test_files/arrow_typeof.slt | 50 ++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 22cc7afe7231f..874080d0d5e2a 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -101,6 +101,35 @@ SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown SELECT arrow_cast('1', 'unknown') +# Round Trip tests: +query TTTTTTTTTTTTTTTTTTT +SELECT + arrow_typeof(arrow_cast(1, 'Int8')) as col_i8, + arrow_typeof(arrow_cast(1, 'Int16')) as col_i16, + arrow_typeof(arrow_cast(1, 'Int32')) as col_i32, + arrow_typeof(arrow_cast(1, 'Int64')) as 
col_i64, + arrow_typeof(arrow_cast(1, 'UInt8')) as col_u8, + arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16, + arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32, + arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64, + -- can't seem to cast to Float16 for some reason + -- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16, + arrow_typeof(arrow_cast(1, 'Float32')) as col_f32, + arrow_typeof(arrow_cast(1, 'Float64')) as col_f64, + arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8, + arrow_typeof(arrow_cast('foo', 'LargeUtf8')) as col_large_utf8, + arrow_typeof(arrow_cast('foo', 'Binary')) as col_binary, + arrow_typeof(arrow_cast('foo', 'LargeBinary')) as col_large_binary, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)')) as col_ts_s, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)')) as col_ts_ms, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)')) as col_ts_us, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)')) as col_ts_ns, + arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict +---- +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Dictionary(Int32, Utf8) + + + ## Basic Types: Create a table @@ -115,6 +144,7 @@ create table foo as select arrow_cast(1, 'UInt32') as col_u32, arrow_cast(1, 'UInt64') as col_u64, -- can't seem to cast to Float16 for some reason + -- arrow_cast(1.0, 'Float16') as col_f16, arrow_cast(1.0, 'Float32') as col_f32, arrow_cast(1.0, 'Float64') as col_f64 ; @@ -131,6 +161,7 @@ SELECT arrow_typeof(col_u16), arrow_typeof(col_u32), arrow_typeof(col_u64), + -- arrow_typeof(col_f16), arrow_typeof(col_f32), arrow_typeof(col_f64) FROM foo; @@ -219,28 +250,23 
@@ Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None statement ok drop table foo - -## Durations: Create a table +## Dictionaries statement ok create table foo as select - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') as col_ts_s, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)') as col_ts_ms, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)') as col_ts_us, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns + arrow_cast('foo', 'Dictionary(Int32, Utf8)') as col_dict_int32_utf8, + arrow_cast('foo', 'Dictionary(Int8, LargeUtf8)') as col_dict_int8_largeutf8 ; ## Ensure each column in the table has the expected type -query TTTT +query TT SELECT - arrow_typeof(col_ts_s), - arrow_typeof(col_ts_ms), - arrow_typeof(col_ts_us), - arrow_typeof(col_ts_ns) + arrow_typeof(col_dict_int32_utf8), + arrow_typeof(col_dict_int8_largeutf8) FROM foo; ---- -Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) +Dictionary(Int32, Utf8) Dictionary(Int8, LargeUtf8) statement ok From b3defe4f8e3e11360cab4f371daf1069644b3f01 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Mar 2023 14:51:16 -0500 Subject: [PATCH 08/11] fix test --- .../core/tests/sqllogictests/test_files/arrow_typeof.slt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt index 874080d0d5e2a..fee24740a6f19 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt @@ -279,7 +279,7 @@ query error Cannot automatically convert Interval\(DayTime\) to Interval\(MonthD --- select arrow_cast(interval '30 minutes', 
'Interval(MonthDayNano)'); -query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Interval(MonthDayNano) +query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Interval\(MonthDayNano\) select arrow_cast('30 minutes', 'Interval(MonthDayNano)'); @@ -289,5 +289,5 @@ query error Cannot automatically convert Interval\(DayTime\) to Duration\(Second --- select arrow_cast(interval '30 minutes', 'Duration(Second)'); -query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration(Second) +query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) select arrow_cast('30 minutes', 'Duration(Second)'); From 84901bf62b7cadddb82e1fb53158f49f6cee9f7e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Mar 2023 14:54:56 -0500 Subject: [PATCH 09/11] reuse buffer to avoid an allocation per word --- datafusion/sql/src/expr/arrow_cast.rs | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index a52bc3064cd90..126e730229b89 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -343,6 +343,8 @@ fn is_separator(c: char) -> bool { struct Tokenizer<'a> { val: &'a str, chars: Peekable>, + // temporary buffer for parsing words + word: String, } impl<'a> Tokenizer<'a> { @@ -350,6 +352,7 @@ impl<'a> Tokenizer<'a> { Self { val, chars: val.chars().peekable(), + word: String::new(), } } @@ -366,30 +369,34 @@ impl<'a> Tokenizer<'a> { /// parse the characters in val starting at pos, until the next /// `,`, `(`, or `)` or end of line fn parse_word(&mut self) -> Result { - let mut word = String::new(); + // reset temp space + self.word.clear(); loop { match self.peek_next_char() { None => break, Some(c) if is_separator(c) => break, Some(c) => { self.next_char(); - word.push(c); + 
self.word.push(c); } } } // if it started with a number, try parsing it as an integer - if let Some(c) = word.chars().next() { + if let Some(c) = self.word.chars().next() { if c == '-' || c.is_numeric() { - let val: i64 = word.parse().map_err(|e| { - make_error(self.val, &format!("parsing {word} as integer: {e}")) + let val: i64 = self.word.parse().map_err(|e| { + make_error( + self.val, + &format!("parsing {} as integer: {e}", self.word), + ) })?; return Ok(Token::Integer(val)); } } // figure out what the word was - let token = match word.as_str() { + let token = match self.word.as_str() { "Null" => Token::SimpleType(DataType::Null), "Boolean" => Token::SimpleType(DataType::Boolean), @@ -437,7 +444,12 @@ impl<'a> Tokenizer<'a> { "None" => Token::None, - _ => return Err(make_error(self.val, &format!("unrecognized word: {word}"))), + _ => { + return Err(make_error( + self.val, + &format!("unrecognized word: {}", self.word), + )) + } }; Ok(token) } From 1e6b64801b3955ef6d5759f22e010b919af63de1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Mar 2023 15:10:18 -0500 Subject: [PATCH 10/11] add ticket link --- datafusion/sql/src/expr/arrow_cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 126e730229b89..eac5f128bf6b3 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -93,7 +93,7 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// assert_eq!(data_type, DataType::Int32); /// ``` /// -/// TODO file a ticket about bringing this into arrow possibly +/// Remove if added to arrow: https://github.com/apache/arrow-rs/issues/3821 pub fn parse_data_type(val: &str) -> Result { Parser::new(val).parse() } From ff5d72bfe304d394db406381752ebdbae7a859ce Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 8 Mar 2023 15:38:27 -0500 Subject: [PATCH 11/11] allow trailing whitespace, add tests for 
whitespace --- datafusion/sql/src/expr/arrow_cast.rs | 36 ++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index eac5f128bf6b3..bc1313e2c114e 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -126,7 +126,7 @@ impl<'a> Parser<'a> { fn parse(mut self) -> Result { let data_type = self.parse_next_type()?; // ensure that there is no trailing content - if self.tokenizer.peek_next_char().is_some() { + if self.tokenizer.next().is_some() { return Err(make_error( self.val, &format!("checking trailing content after parsing '{data_type}'"), @@ -613,6 +613,10 @@ mod test { // --------- DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Timestamp(TimeUnit::Nanosecond, None)), + ), DataType::Dictionary( Box::new(DataType::Int8), Box::new(DataType::FixedSizeBinary(23)), @@ -631,6 +635,36 @@ mod test { ] } + #[test] + fn test_parse_data_type_whitespace_tolerance() { + // (string to parse, expected DataType) + let cases = [ + ("Int8", DataType::Int8), + ( + "Timestamp (Nanosecond, None)", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ( + "Timestamp (Nanosecond, None) ", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ( + " Timestamp (Nanosecond, None )", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ( + "Timestamp (Nanosecond, None ) ", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ]; + + for (data_type_string, expected_data_type) in cases { + println!("Parsing '{data_type_string}', expecting '{expected_data_type:?}'"); + let parsed_data_type = parse_data_type(data_type_string).unwrap(); + assert_eq!(parsed_data_type, expected_data_type); + } + } + #[test] fn parse_data_type_errors() { // (string to parse, 
expected error message)