diff --git a/datafusion/core/tests/sql/cast.rs b/datafusion/core/tests/sql/cast.rs
new file mode 100644
index 0000000000000..61bac0eb2c22e
--- /dev/null
+++ b/datafusion/core/tests/sql/cast.rs
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::sql::execute_to_batches;
+use arrow::datatypes::DataType;
+use arrow::record_batch::RecordBatch;
+use datafusion::error::Result;
+use datafusion::prelude::SessionContext;
+
+async fn execute_sql(sql: &str) -> Vec<RecordBatch> {
+    let ctx = SessionContext::new();
+    execute_to_batches(&ctx, sql).await
+}
+
+#[tokio::test]
+async fn cast_tinyint() -> Result<()> {
+    let actual = execute_sql("SELECT cast(10 as tinyint)").await;
+    assert_eq!(&DataType::Int8, actual[0].schema().field(0).data_type());
+    Ok(())
+}
+
+#[tokio::test]
+async fn cast_tinyint_operator() -> Result<()> {
+    let actual = execute_sql("SELECT 10::tinyint").await;
+    assert_eq!(&DataType::Int8, actual[0].schema().field(0).data_type());
+    Ok(())
+}
+
+#[tokio::test]
+async fn cast_unsigned_tinyint() -> Result<()> {
+    let actual = execute_sql("SELECT 10::tinyint unsigned").await;
+    assert_eq!(&DataType::UInt8, actual[0].schema().field(0).data_type());
+    Ok(())
+}
+
+#[tokio::test]
+async fn cast_unsigned_smallint() -> Result<()> {
+    let actual = execute_sql("SELECT 10::smallint unsigned").await;
+    assert_eq!(&DataType::UInt16, actual[0].schema().field(0).data_type());
+    Ok(())
+}
+
+#[tokio::test]
+async fn cast_unsigned_int() -> Result<()> {
+    let actual = execute_sql("SELECT 10::integer unsigned").await;
+    assert_eq!(&DataType::UInt32, actual[0].schema().field(0).data_type());
+    Ok(())
+}
+
+#[tokio::test]
+async fn cast_unsigned_bigint() -> Result<()> {
+    let actual = execute_sql("SELECT 10::bigint unsigned").await;
+    assert_eq!(&DataType::UInt64, actual[0].schema().field(0).data_type());
+    Ok(())
+}
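To exercise the new casts outside the test harness: `execute_to_batches` above is a test-only helper, so a standalone check has to go through the public `SessionContext` API instead. A minimal sketch, assuming a build of DataFusion that includes this change:

```rust
use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // The `::` cast operator and CAST(... AS ...) both go through the same
    // SQL-type-to-Arrow-type mapping in the planner.
    let batches = ctx
        .sql("SELECT 10::TINYINT UNSIGNED")
        .await?
        .collect()
        .await?;
    assert_eq!(&DataType::UInt8, batches[0].schema().field(0).data_type());
    Ok(())
}
```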
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 473ded8ff7358..894d45564272d 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -196,8 +196,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
-        "      TableScan: aggregate_test_100 [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
+        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
+        "      TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -243,9 +243,9 @@ async fn csv_explain_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -271,8 +271,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
+        "    Filter: #aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -286,8 +286,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain",
         "  Projection: #aggregate_test_100.c1",
-        "    Filter: #aggregate_test_100.c2 > Int32(10)",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
+        "    Filter: #aggregate_test_100.c2 > Int8(10)",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]",
     ];
     let formatted = plan.display_indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -307,9 +307,9 @@ async fn csv_explain_plans() {
         "    2[shape=box label=\"Explain\"]",
         "    3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         "    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
+        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\"]",
         "    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
+        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\"]",
         "    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "  subgraph cluster_6",
@@ -318,9 +318,9 @@ async fn csv_explain_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -349,7 +349,7 @@ async fn csv_explain_plans() {
     // Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content
     assert_contains!(&actual, "logical_plan");
     assert_contains!(&actual, "Projection: #aggregate_test_100.c1");
-    assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)");
+    assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int8(10)");
 }
 
 #[tokio::test]
@@ -381,8 +381,8 @@ async fn csv_explain_inlist_verbose() {
     let actual = execute(&ctx, sql).await;
 
     // Optimized by PreCastLitInComparisonExpressions rule
-    // the data type of c2 is INT32, the type of `1,2,3,4` is INT64.
-    // the value of `1,2,4` will be casted to INT32 and pre-calculated
+    // the data type of c2 is INT8, the type of `1,2,4` is INT64.
+    // the value of `1,2,4` will be cast to INT8 and pre-calculated
     // flatten to a single string
     let actual = actual.into_iter().map(|r| r.join("\t")).collect::<String>();
 
@@ -392,10 +392,10 @@ async fn csv_explain_inlist_verbose() {
         &actual,
         "#aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])"
     );
-    // after optimization (casted to Int32)
+    // after optimization (cast to Int8)
    assert_contains!(
         &actual,
-        "#aggregate_test_100.c2 IN ([Int32(1), Int32(2), Int32(4)])"
+        "#aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])"
     );
 }
 
@@ -420,8 +420,8 @@ async fn csv_explain_verbose_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
-        "      TableScan: aggregate_test_100 [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
+        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
+        "      TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -467,9 +467,9 @@ async fn csv_explain_verbose_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -495,8 +495,8 @@ async fn csv_explain_verbose_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
+        "    Filter: #aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -510,8 +510,8 @@ async fn csv_explain_verbose_plans() {
     let expected = vec![
         "Explain",
         "  Projection: #aggregate_test_100.c1",
-        "    Filter: #aggregate_test_100.c2 > Int32(10)",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
+        "    Filter: #aggregate_test_100.c2 > Int8(10)",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]",
     ];
     let formatted = plan.display_indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -531,9 +531,9 @@ async fn csv_explain_verbose_plans() {
         "    2[shape=box label=\"Explain\"]",
         "    3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         "    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
+        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\"]",
         "    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
+        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\"]",
         "    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "  subgraph cluster_6",
@@ -542,9 +542,9 @@ async fn csv_explain_verbose_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -781,8 +781,8 @@ async fn csv_explain() {
         vec![
             "logical_plan",
             "Projection: #aggregate_test_100.c1\
-            \n  Filter: #aggregate_test_100.c2 > Int32(10)\
-            \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
+            \n  Filter: #aggregate_test_100.c2 > Int8(10)\
+            \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]"
         ],
         vec!["physical_plan",
             "ProjectionExec: expr=[c1@0 as c1]\
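The inlist hunks above document the `PreCastLitInComparisonExpressions` rule at work: once `c2` is `Int8`, the `Int64` literals in the IN list are cast down once at plan time instead of casting `c2` up for every row. A hedged sketch of how to observe the rewrite yourself, assuming a registered table `t` with an `Int8` column `a` (both names are placeholders):

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn show_literal_precast(ctx: &SessionContext) -> Result<()> {
    // EXPLAIN VERBOSE prints the plan after each optimizer pass, so both the
    // pre-optimization Int64 literals and the rewritten Int8 literals appear.
    let batches = ctx
        .sql("EXPLAIN VERBOSE SELECT a FROM t WHERE a IN (1, 2, 4)")
        .await?
        .collect()
        .await?;
    for batch in batches {
        println!("{:?}", batch);
    }
    Ok(())
}
```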
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 131c3c8bcf64f..ef4b4386e0a81 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -82,6 +82,7 @@ macro_rules! test_expression {
 pub mod aggregates;
 #[cfg(feature = "avro")]
 pub mod avro;
+pub mod cast;
 pub mod create_drop;
 pub mod errors;
 pub mod explain_analyze;
@@ -621,22 +622,20 @@ async fn register_tpch_csv_data(
 async fn register_aggregate_csv_by_sql(ctx: &SessionContext) {
     let testdata = datafusion::test_util::arrow_test_data();
 
-    // TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once
-    // unsigned is supported.
     let df = ctx
         .sql(&format!(
             "
     CREATE EXTERNAL TABLE aggregate_test_100 (
         c1  VARCHAR NOT NULL,
-        c2  INT NOT NULL,
+        c2  TINYINT NOT NULL,
         c3  SMALLINT NOT NULL,
         c4  SMALLINT NOT NULL,
         c5  INTEGER NOT NULL,
         c6  BIGINT NOT NULL,
         c7  SMALLINT NOT NULL,
         c8  INT NOT NULL,
-        c9  BIGINT NOT NULL,
-        c10 VARCHAR NOT NULL,
+        c9  INT UNSIGNED NOT NULL,
+        c10 BIGINT UNSIGNED NOT NULL,
         c11 FLOAT NOT NULL,
         c12 DOUBLE NOT NULL,
         c13 VARCHAR NOT NULL
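With the planner change below, the same `TINYINT`/`UNSIGNED` column types work in any `CREATE EXTERNAL TABLE` statement, not just this test fixture. A minimal sketch, assuming a local file `data.csv` (a placeholder) whose columns match the declared schema:

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn register_unsigned_table(ctx: &SessionContext) -> Result<()> {
    // TINYINT -> Int8, INT UNSIGNED -> UInt32, BIGINT UNSIGNED -> UInt64,
    // per the new mapping in convert_simple_data_type.
    ctx.sql(
        "CREATE EXTERNAL TABLE example (
            c1 TINYINT NOT NULL,
            c2 INT UNSIGNED NOT NULL,
            c3 BIGINT UNSIGNED NOT NULL
        )
        STORED AS CSV
        LOCATION 'data.csv'",
    )
    .await?;
    Ok(())
}
```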
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 81ac69a940c82..5cc388ba01a43 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -519,7 +519,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let mut fields = Vec::with_capacity(columns.len());
 
         for column in columns {
-            let data_type = self.make_data_type(&column.data_type)?;
+            let data_type = convert_simple_data_type(&column.data_type)?;
             let allow_null = column
                 .options
                 .iter()
@@ -534,56 +534,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         Ok(Schema::new(fields))
     }
 
-    /// Maps the SQL type to the corresponding Arrow `DataType`
-    fn make_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
-        match sql_type {
-            SQLDataType::BigInt(_) => Ok(DataType::Int64),
-            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(DataType::Int32),
-            SQLDataType::SmallInt(_) => Ok(DataType::Int16),
-            SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => {
-                Ok(DataType::Utf8)
-            }
-            SQLDataType::Decimal(precision, scale) => {
-                make_decimal_type(*precision, *scale)
-            }
-            SQLDataType::Float(_) => Ok(DataType::Float32),
-            SQLDataType::Real => Ok(DataType::Float32),
-            SQLDataType::Double => Ok(DataType::Float64),
-            SQLDataType::Boolean => Ok(DataType::Boolean),
-            SQLDataType::Date => Ok(DataType::Date32),
-            SQLDataType::Time => Ok(DataType::Time64(TimeUnit::Nanosecond)),
-            SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
-            // Explicitly list all other types so that if sqlparser
-            // adds/changes the `SQLDataType` the compiler will tell us on upgrade
-            // and avoid bugs like https://github.com/apache/arrow-datafusion/issues/3059
-            SQLDataType::Nvarchar(_)
-            | SQLDataType::Uuid
-            | SQLDataType::Binary(_)
-            | SQLDataType::Varbinary(_)
-            | SQLDataType::Blob(_)
-            | SQLDataType::TinyInt(_)
-            | SQLDataType::UnsignedTinyInt(_)
-            | SQLDataType::UnsignedSmallInt(_)
-            | SQLDataType::UnsignedInt(_)
-            | SQLDataType::UnsignedInteger(_)
-            | SQLDataType::UnsignedBigInt(_)
-            | SQLDataType::Datetime
-            | SQLDataType::TimestampTz
-            | SQLDataType::Interval
-            | SQLDataType::Regclass
-            | SQLDataType::String
-            | SQLDataType::Bytea
-            | SQLDataType::Custom(_)
-            | SQLDataType::Array(_)
-            | SQLDataType::Enum(_)
-            | SQLDataType::Set(_)
-            | SQLDataType::Clob(_) => Err(DataFusionError::NotImplemented(format!(
-                "The SQL data type {:?} is not implemented",
-                sql_type
-            ))),
-        }
-    }
-
     fn plan_from_tables(
         &self,
         from: Vec<TableWithJoins>,
@@ -2668,9 +2618,16 @@ fn extract_possible_join_keys(
 pub fn convert_simple_data_type(sql_type: &SQLDataType) -> Result<DataType> {
     match sql_type {
         SQLDataType::Boolean => Ok(DataType::Boolean),
+        SQLDataType::TinyInt(_) => Ok(DataType::Int8),
         SQLDataType::SmallInt(_) => Ok(DataType::Int16),
-        SQLDataType::Int(_) => Ok(DataType::Int32),
+        SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(DataType::Int32),
         SQLDataType::BigInt(_) => Ok(DataType::Int64),
+        SQLDataType::UnsignedTinyInt(_) => Ok(DataType::UInt8),
+        SQLDataType::UnsignedSmallInt(_) => Ok(DataType::UInt16),
+        SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => {
+            Ok(DataType::UInt32)
+        }
+        SQLDataType::UnsignedBigInt(_) => Ok(DataType::UInt64),
         SQLDataType::Float(_) => Ok(DataType::Float32),
         SQLDataType::Real => Ok(DataType::Float32),
         SQLDataType::Double => Ok(DataType::Float64),
@@ -2682,11 +2639,26 @@ pub fn convert_simple_data_type(sql_type: &SQLDataType) -> Result<DataType> {
         SQLDataType::Date => Ok(DataType::Date32),
         SQLDataType::Time => Ok(DataType::Time64(TimeUnit::Nanosecond)),
         SQLDataType::Decimal(precision, scale) => make_decimal_type(*precision, *scale),
-        SQLDataType::Binary(_) => Ok(DataType::Binary),
         SQLDataType::Bytea => Ok(DataType::Binary),
-        other => Err(DataFusionError::NotImplemented(format!(
+        // Explicitly list all other types so that if sqlparser
+        // adds/changes the `SQLDataType` the compiler will tell us on upgrade
+        // and avoid bugs like https://github.com/apache/arrow-datafusion/issues/3059
+        SQLDataType::Nvarchar(_)
+        | SQLDataType::Uuid
+        | SQLDataType::Binary(_)
+        | SQLDataType::Varbinary(_)
+        | SQLDataType::Blob(_)
+        | SQLDataType::Datetime
+        | SQLDataType::TimestampTz
+        | SQLDataType::Interval
+        | SQLDataType::Regclass
+        | SQLDataType::Custom(_)
+        | SQLDataType::Array(_)
+        | SQLDataType::Enum(_)
+        | SQLDataType::Set(_)
+        | SQLDataType::Clob(_) => Err(DataFusionError::NotImplemented(format!(
            "Unsupported SQL type {:?}",
-            other
+            sql_type
        ))),
    }
 }
@@ -2756,6 +2728,15 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_tinyint() {
+        quick_test(
+            "SELECT CAST(6 AS TINYINT)",
+            "Projection: CAST(Int64(6) AS Int8)\
+            \n  EmptyRelation",
+        );
+    }
+
     #[test]
     fn test_int_decimal_scale_larger_precision() {
         let sql = "SELECT CAST(10 AS DECIMAL(5, 10))";
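Because `convert_simple_data_type` is a free `pub fn`, the new arms can also be checked without planning a whole query. A sketch, assuming `datafusion-sql`, `datafusion-common`, `arrow`, and the `sqlparser` version pinned by this branch are all on the dependency list (the integer `SQLDataType` variants carry an optional display width, hence the `None`):

```rust
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_sql::planner::convert_simple_data_type;
use sqlparser::ast::DataType as SQLDataType;

fn check_new_mappings() -> Result<()> {
    // Mirrors three of the new match arms added above.
    assert_eq!(
        DataType::Int8,
        convert_simple_data_type(&SQLDataType::TinyInt(None))?
    );
    assert_eq!(
        DataType::UInt32,
        convert_simple_data_type(&SQLDataType::UnsignedInt(None))?
    );
    assert_eq!(
        DataType::UInt64,
        convert_simple_data_type(&SQLDataType::UnsignedBigInt(None))?
    );
    Ok(())
}
```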
diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md
index 3325d4a774c3e..1cd79743bacef 100644
--- a/docs/source/user-guide/sql/data_types.md
+++ b/docs/source/user-guide/sql/data_types.md
@@ -23,6 +23,7 @@ DataFusion uses Arrow, and thus the Arrow type system, for query execution.
 The SQL types from
 [sqlparser-rs](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/ast/data_type.rs#L27)
 are mapped to [Arrow data types](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) according to the following table.
+This mapping occurs when defining the schema in a `CREATE EXTERNAL TABLE` command or when performing a SQL `CAST` operation.
 
 ## Character Types
 
@@ -34,15 +35,20 @@ are mapped to [Arrow data types](https://docs.rs/arrow/latest/arrow/datatypes/en
 ## Numeric Types
 
-| SQL DataType       | Arrow DataType    |
-| ------------------ | :---------------- |
-| `SMALLINT`         | `Int16`           |
-| `INT` or `INTEGER` | `Int32`           |
-| `BIGINT`           | `Int64`           |
-| `FLOAT`            | `Float32`         |
-| `REAL`             | `Float32`         |
-| `DOUBLE`           | `Float64`         |
-| `DECIMAL(p,s)`     | `Decimal128(p,s)` |
+| SQL DataType                         | Arrow DataType    |
+| ------------------------------------ | :---------------- |
+| `TINYINT`                            | `Int8`            |
+| `SMALLINT`                           | `Int16`           |
+| `INT` or `INTEGER`                   | `Int32`           |
+| `BIGINT`                             | `Int64`           |
+| `TINYINT UNSIGNED`                   | `UInt8`           |
+| `SMALLINT UNSIGNED`                  | `UInt16`          |
+| `INT UNSIGNED` or `INTEGER UNSIGNED` | `UInt32`          |
+| `BIGINT UNSIGNED`                    | `UInt64`          |
+| `FLOAT`                              | `Float32`         |
+| `REAL`                               | `Float32`         |
+| `DOUBLE`                             | `Float64`         |
+| `DECIMAL(p,s)`                       | `Decimal128(p,s)` |
 
 ## Date/Time Types
 
@@ -58,27 +64,27 @@ are mapped to [Arrow data types](https://docs.rs/arrow/latest/arrow/datatypes/en
 | ------------ | :------------- |
 | `BOOLEAN`    | `Boolean`      |
 
+## Binary Types
+
+| SQL DataType | Arrow DataType |
+| ------------ | :------------- |
+| `BYTEA`      | `Binary`       |
+
 ## Unsupported Types
 
-| SQL Data Type       | Arrow DataType      |
-| ------------------- | :------------------ |
-| `UUID`              | _Not yet supported_ |
-| `BLOB`              | _Not yet supported_ |
-| `CLOB`              | _Not yet supported_ |
-| `BINARY`            | _Not yet supported_ |
-| `VARBINARY`         | _Not yet supported_ |
-| `BYTEA`             | _Not yet supported_ |
-| `REGCLASS`          | _Not yet supported_ |
-| `NVARCHAR`          | _Not yet supported_ |
-| `STRING`            | _Not yet supported_ |
-| `CUSTOM`            | _Not yet supported_ |
-| `ARRAY`             | _Not yet supported_ |
-| `ENUM`              | _Not yet supported_ |
-| `SET`               | _Not yet supported_ |
-| `INTERVAL`          | _Not yet supported_ |
-| `DATETIME`          | _Not yet supported_ |
-| `TINYINT`           | _Not yet supported_ |
-| `UNSIGNED TINYINT`  | _Not yet supported_ |
-| `UNSIGNED SMALLINT` | _Not yet supported_ |
-| `UNSIGNED INT`      | _Not yet supported_ |
-| `UNSIGNED BIGINT`   | _Not yet supported_ |
+| SQL Data Type | Arrow DataType      |
+| ------------- | :------------------ |
+| `UUID`        | _Not yet supported_ |
+| `BLOB`        | _Not yet supported_ |
+| `CLOB`        | _Not yet supported_ |
+| `BINARY`      | _Not yet supported_ |
+| `VARBINARY`   | _Not yet supported_ |
+| `REGCLASS`    | _Not yet supported_ |
+| `NVARCHAR`    | _Not yet supported_ |
+| `STRING`      | _Not yet supported_ |
+| `CUSTOM`      | _Not yet supported_ |
+| `ARRAY`       | _Not yet supported_ |
+| `ENUM`        | _Not yet supported_ |
+| `SET`         | _Not yet supported_ |
+| `INTERVAL`    | _Not yet supported_ |
+| `DATETIME`    | _Not yet supported_ |
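Since the documented table now has to stay in sync with `convert_simple_data_type`, individual rows are cheap to spot-check end to end. A hedged sketch verifying a few of the new numeric mappings against a build that includes this change:

```rust
use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Pairs of (SQL type from the docs table, expected Arrow type).
    let cases = [
        ("TINYINT", DataType::Int8),
        ("SMALLINT UNSIGNED", DataType::UInt16),
        ("INTEGER UNSIGNED", DataType::UInt32),
        ("BIGINT UNSIGNED", DataType::UInt64),
    ];
    for (sql_type, expected) in cases {
        let batches = ctx
            .sql(&format!("SELECT 1::{}", sql_type))
            .await?
            .collect()
            .await?;
        assert_eq!(&expected, batches[0].schema().field(0).data_type());
    }
    Ok(())
}
```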