diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 9952f8cae2f..6f953bad139 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -658,7 +658,7 @@ mod tests { let mut ctx = create_ctx(&tmp_dir, partition_count)?; let logical_plan = - ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?; + ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE CAST(c1 AS double) > 0 AND CAST(c1 AS double) < 3")?; let logical_plan = ctx.optimize(&logical_plan)?; let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?; diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 704c5482543..f92ea8eea77 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -288,14 +288,14 @@ impl Expr { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { Ok(self.clone()) - } else if can_coerce_from(cast_to_type, &this_type) { + } else if cast_supported(cast_to_type, &this_type) { Ok(Expr::Cast { expr: Box::new(self.clone()), data_type: cast_to_type.clone(), }) } else { Err(ExecutionError::General(format!( - "Cannot automatically convert {:?} to {:?}", + "Cannot cast from {:?} to {:?}", this_type, cast_to_type ))) } @@ -723,23 +723,28 @@ impl fmt::Debug for LogicalPlan { } /// Verify a given type cast can be performed -pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { +pub fn cast_supported(type_into: &DataType, type_from: &DataType) -> bool { use self::DataType::*; + + if type_from == type_into { + return true; + } + match type_into { Int8 => match type_from { Int8 => true, _ => false, }, Int16 => match type_from { - Int8 | Int16 | UInt8 => true, + Int8 | Int16 => true, _ => false, }, Int32 => match type_from { - Int8 | Int16 | Int32 | UInt8 | UInt16 => true, + Int8 | Int16 | Int32 => true, _ => false, }, Int64 => match type_from { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 => true, + Int8 | Int16 | Int32 | Int64 => true, _ => false, }, UInt8 => match type_from { diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 03a85ba4c85..349627e33f9 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -187,8 +187,10 @@ mod tests { use crate::execution::context::ExecutionContext; use crate::execution::physical_plan::csv::CsvReadOptions; use crate::logicalplan::Expr::*; - use crate::logicalplan::{col, Operator}; + use crate::logicalplan::{cast_supported, col, Operator}; + use crate::optimizer::utils::get_supertype; use crate::test::arrow_testdata_path; + use arrow::datatypes::DataType::*; use arrow::datatypes::{DataType, Field, Schema}; #[test] @@ -212,6 +214,30 @@ mod tests { Ok(()) } + #[test] + fn test_type_matrix() -> Result<()> { + let types = vec![ + Boolean, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, + Float64, Utf8, + ]; + + for from_type in &types { + for to_type in &types { + match get_supertype(from_type, to_type) { + Ok(t) => { + // swapping from and to should result in same supertype + assert_eq!(t, get_supertype(to_type, from_type)?); + // both from and to types should be coercable to the supertype + assert!(cast_supported(&t, &from_type)); + assert!(cast_supported(&t, &to_type)); + } + Err(_) => assert!(get_supertype(to_type, from_type).is_err()), + } + } + } + Ok(()) + } + #[test] fn test_add_i32_i64() { binary_cast_test( @@ -254,20 +280,6 @@ mod tests { ); } - #[test] - fn test_add_u32_i64() { - binary_cast_test( - DataType::UInt32, - DataType::Int64, - "CAST(#0 AS Int64) Plus #1", - ); - binary_cast_test( - DataType::Int64, - DataType::UInt32, - "#0 Plus CAST(#1 AS Int64)", - ); - } - fn binary_cast_test(left_type: DataType, right_type: DataType, expected: &str) { let schema = Schema::new(vec![ Field::new("c0", left_type, true), @@ -285,6 +297,6 @@ mod tests { let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); - assert_eq!(expected, format!("{:?}", expr2)); + assert_eq!(format!("{:?}", expr2), expected); } } diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index f9467a12091..56046360ecd 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -131,116 +131,78 @@ pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result Result { - match _get_supertype(l, r) { - Some(dt) => Ok(dt), - None => _get_supertype(r, l).ok_or_else(|| { - ExecutionError::InternalError(format!( - "Failed to determine supertype of {:?} and {:?}", - l, r - )) - }), - } -} - -/// Given two datatypes, determine the supertype that both types can safely be cast to -fn _get_supertype(l: &DataType, r: &DataType) -> Option { use arrow::datatypes::DataType::*; - match (l, r) { - (UInt8, Int8) => Some(Int8), - (UInt8, Int16) => Some(Int16), - (UInt8, Int32) => Some(Int32), - (UInt8, Int64) => Some(Int64), - - (UInt16, Int16) => Some(Int16), - (UInt16, Int32) => Some(Int32), - (UInt16, Int64) => Some(Int64), - - (UInt32, Int32) => Some(Int32), - (UInt32, Int64) => Some(Int64), - - (UInt64, Int64) => Some(Int64), - - (Int8, UInt8) => Some(Int8), - - (Int16, UInt8) => Some(Int16), - (Int16, UInt16) => Some(Int16), - - (Int32, UInt8) => Some(Int32), - (Int32, UInt16) => Some(Int32), - (Int32, UInt32) => Some(Int32), - - (Int64, UInt8) => Some(Int64), - (Int64, UInt16) => Some(Int64), - (Int64, UInt32) => Some(Int64), - (Int64, UInt64) => Some(Int64), - - (UInt8, UInt8) => Some(UInt8), - (UInt8, UInt16) => Some(UInt16), - (UInt8, UInt32) => Some(UInt32), - (UInt8, UInt64) => Some(UInt64), - (UInt8, Float32) => Some(Float32), - (UInt8, Float64) => Some(Float64), - (UInt16, UInt8) => Some(UInt16), - (UInt16, UInt16) => Some(UInt16), - (UInt16, UInt32) => Some(UInt32), - (UInt16, UInt64) => Some(UInt64), - (UInt16, Float32) => Some(Float32), - (UInt16, Float64) => Some(Float64), - - (UInt32, UInt8) => Some(UInt32), - (UInt32, UInt16) => Some(UInt32), - (UInt32, UInt32) => Some(UInt32), - (UInt32, UInt64) => Some(UInt64), - (UInt32, Float32) => Some(Float32), - (UInt32, Float64) => Some(Float64), - - (UInt64, UInt8) => Some(UInt64), - (UInt64, UInt16) => Some(UInt64), - (UInt64, UInt32) => Some(UInt64), - (UInt64, UInt64) => Some(UInt64), - (UInt64, Float32) => Some(Float32), - (UInt64, Float64) => Some(Float64), - - (Int8, Int8) => Some(Int8), - (Int8, Int16) => Some(Int16), - (Int8, Int32) => Some(Int32), - (Int8, Int64) => Some(Int64), - (Int8, Float32) => Some(Float32), - (Int8, Float64) => Some(Float64), - - (Int16, Int8) => Some(Int16), - (Int16, Int16) => Some(Int16), - (Int16, Int32) => Some(Int32), - (Int16, Int64) => Some(Int64), - (Int16, Float32) => Some(Float32), - (Int16, Float64) => Some(Float64), - - (Int32, Int8) => Some(Int32), - (Int32, Int16) => Some(Int32), - (Int32, Int32) => Some(Int32), - (Int32, Int64) => Some(Int64), - (Int32, Float32) => Some(Float32), - (Int32, Float64) => Some(Float64), - - (Int64, Int8) => Some(Int64), - (Int64, Int16) => Some(Int64), - (Int64, Int32) => Some(Int64), - (Int64, Int64) => Some(Int64), - (Int64, Float32) => Some(Float32), - (Int64, Float64) => Some(Float64), - - (Float32, Float32) => Some(Float32), - (Float32, Float64) => Some(Float64), - (Float64, Float32) => Some(Float64), - (Float64, Float64) => Some(Float64), - - (Utf8, _) => Some(Utf8), - (_, Utf8) => Some(Utf8), - - (Boolean, Boolean) => Some(Boolean), + if l == r { + return Ok(l.clone()); + } + let super_type = match l { + UInt8 => match r { + UInt16 | UInt32 | UInt64 => Some(r.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + UInt16 => match r { + UInt8 => Some(l.clone()), + UInt32 | UInt64 => Some(r.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + UInt32 => match r { + UInt8 | UInt16 => Some(l.clone()), + UInt64 => Some(r.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + UInt64 => match r { + UInt8 | UInt16 | UInt32 => Some(l.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + Int8 => match r { + Int16 | Int32 | Int64 => Some(r.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + Int16 => match r { + Int8 => Some(l.clone()), + Int32 | Int64 => Some(r.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + Int32 => match r { + Int8 | Int16 => Some(l.clone()), + Int64 => Some(r.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + Int64 => match r { + Int8 | Int16 | Int32 => Some(l.clone()), + Float32 | Float64 => Some(r.clone()), + _ => None, + }, + Float32 => match r { + Int8 | Int16 | Int32 | Int64 => Some(Float32), + UInt8 | UInt16 | UInt32 | UInt64 => Some(Float32), + Float64 => Some(Float64), + _ => None, + }, + Float64 => match r { + Int8 | Int16 | Int32 | Int64 => Some(Float64), + UInt8 | UInt16 | UInt32 | UInt64 => Some(Float64), + Float32 | Float64 => Some(Float64), + _ => None, + }, _ => None, + }; + + match super_type { + Some(dt) => Ok(dt), + None => Err(ExecutionError::InternalError(format!( + "Failed to determine supertype of {:?} and {:?}", + l, r + ))), } }