diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 10f6a446f7d..efffef88767 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -98,117 +98,117 @@ pub fn exprlist_to_fields(expr: &Vec, input_schema: &Schema) -> Result Result { - match _get_supertype(l, r) { - Some(dt) => Ok(dt), - None => match _get_supertype(r, l) { - Some(dt) => Ok(dt), - None => Err(ExecutionError::InternalError(format!( - "Failed to determine supertype of {:?} and {:?}", - l, r - ))), - }, + use arrow::datatypes::DataType::*; + let d = if l == r { + Some(l.clone()) + } else if l == &Utf8 || r == &Utf8 { + Some(Utf8) + } else if is_signed_int(l) { + if is_signed_int(r) { + if bit_width(l) >= bit_width(r) { + Some(l.clone()) + } else { + Some(r.clone()) + } + } else if is_unsigned_int(r) { + match bit_width(l).max(bit_width(r)) { + 8 => Some(Int8), + 16 => Some(Int16), + 32 => Some(Int32), + 64 => Some(Int64), + _ => None, + } + } else if is_floating_point(r) { + Some(r.clone()) + } else { + None + } + } else if is_unsigned_int(l) { + if is_signed_int(r) { + match bit_width(l).max(bit_width(r)) { + 8 => Some(Int8), + 16 => Some(Int16), + 32 => Some(Int32), + 64 => Some(Int64), + _ => None, + } + } else if is_unsigned_int(r) { + if bit_width(l) >= bit_width(r) { + Some(l.clone()) + } else { + Some(r.clone()) + } + } else if is_floating_point(r) { + Some(r.clone()) + } else { + None + } + } else if is_floating_point(l) { + if is_int(r) { + Some(l.clone()) + } else if is_floating_point(r) { + if bit_width(l) >= bit_width(r) { + Some(l.clone()) + } else { + Some(r.clone()) + } + } else { + None + } + } else { + None + }; + + match d { + Some(dd) => Ok(dd), + None => Err(ExecutionError::General(format!( + "Could not determine supertype for {:?} and {:?}", + l, r + ))), } } -/// Given two datatypes, determine the supertype that both types can safely be cast to -fn _get_supertype(l: &DataType, r: &DataType) -> Option { +fn is_int(d: &DataType) -> bool { + use arrow::datatypes::DataType::*; + match d { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => true, + _ => false, + } +} + +fn is_signed_int(d: &DataType) -> bool { + use arrow::datatypes::DataType::*; + match d { + Int8 | Int16 | Int32 | Int64 => true, + _ => false, + } +} + +fn is_unsigned_int(d: &DataType) -> bool { use arrow::datatypes::DataType::*; - match (l, r) { - (UInt8, Int8) => Some(Int8), - (UInt8, Int16) => Some(Int16), - (UInt8, Int32) => Some(Int32), - (UInt8, Int64) => Some(Int64), - - (UInt16, Int16) => Some(Int16), - (UInt16, Int32) => Some(Int32), - (UInt16, Int64) => Some(Int64), - - (UInt32, Int32) => Some(Int32), - (UInt32, Int64) => Some(Int64), - - (UInt64, Int64) => Some(Int64), - - (Int8, UInt8) => Some(Int8), - - (Int16, UInt8) => Some(Int16), - (Int16, UInt16) => Some(Int16), - - (Int32, UInt8) => Some(Int32), - (Int32, UInt16) => Some(Int32), - (Int32, UInt32) => Some(Int32), - - (Int64, UInt8) => Some(Int64), - (Int64, UInt16) => Some(Int64), - (Int64, UInt32) => Some(Int64), - (Int64, UInt64) => Some(Int64), - - (UInt8, UInt8) => Some(UInt8), - (UInt8, UInt16) => Some(UInt16), - (UInt8, UInt32) => Some(UInt32), - (UInt8, UInt64) => Some(UInt64), - (UInt8, Float32) => Some(Float32), - (UInt8, Float64) => Some(Float64), - - (UInt16, UInt8) => Some(UInt16), - (UInt16, UInt16) => Some(UInt16), - (UInt16, UInt32) => Some(UInt32), - (UInt16, UInt64) => Some(UInt64), - (UInt16, Float32) => Some(Float32), - (UInt16, Float64) => Some(Float64), - - (UInt32, UInt8) => Some(UInt32), - (UInt32, UInt16) => Some(UInt32), - (UInt32, UInt32) => Some(UInt32), - (UInt32, UInt64) => Some(UInt64), - (UInt32, Float32) => Some(Float32), - (UInt32, Float64) => Some(Float64), - - (UInt64, UInt8) => Some(UInt64), - (UInt64, UInt16) => Some(UInt64), - (UInt64, UInt32) => Some(UInt64), - (UInt64, UInt64) => Some(UInt64), - (UInt64, Float32) => Some(Float32), - (UInt64, Float64) => Some(Float64), - - (Int8, Int8) => Some(Int8), - (Int8, Int16) => Some(Int16), - (Int8, Int32) => Some(Int32), - (Int8, Int64) => Some(Int64), - (Int8, Float32) => Some(Float32), - (Int8, Float64) => Some(Float64), - - (Int16, Int8) => Some(Int16), - (Int16, Int16) => Some(Int16), - (Int16, Int32) => Some(Int32), - (Int16, Int64) => Some(Int64), - (Int16, Float32) => Some(Float32), - (Int16, Float64) => Some(Float64), - - (Int32, Int8) => Some(Int32), - (Int32, Int16) => Some(Int32), - (Int32, Int32) => Some(Int32), - (Int32, Int64) => Some(Int64), - (Int32, Float32) => Some(Float32), - (Int32, Float64) => Some(Float64), - - (Int64, Int8) => Some(Int64), - (Int64, Int16) => Some(Int64), - (Int64, Int32) => Some(Int64), - (Int64, Int64) => Some(Int64), - (Int64, Float32) => Some(Float32), - (Int64, Float64) => Some(Float64), - - (Float32, Float32) => Some(Float32), - (Float32, Float64) => Some(Float64), - (Float64, Float32) => Some(Float64), - (Float64, Float64) => Some(Float64), - - (Utf8, _) => Some(Utf8), - (_, Utf8) => Some(Utf8), - - (Boolean, Boolean) => Some(Boolean), - - _ => None, + match d { + UInt8 | UInt16 | UInt32 | UInt64 => true, + _ => false, + } +} + +fn is_floating_point(d: &DataType) -> bool { + use arrow::datatypes::DataType::*; + match d { + Float32 | Float64 => true, + _ => false, + } +} + +fn bit_width(d: &DataType) -> usize { + use arrow::datatypes::DataType::*; + match d { + Int8 | UInt8 => 8, + Int16 | UInt16 => 16, + Int32 | UInt32 | Float32 => 32, + Int64 | UInt64 | Float64 => 64, + _ => 0, } } @@ -217,9 +217,133 @@ mod tests { use super::*; use crate::logicalplan::Expr; use arrow::datatypes::DataType; + use arrow::datatypes::DataType::*; use std::collections::HashSet; use std::sync::Arc; + #[test] + fn test_supertype_numeric_types() { + let types = vec![ + UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64, + ]; + let mut result = String::new(); + for l in &types { + for r in &types { + result.push_str(&format!( + "supertype of {:?} and {:?} is {:?} +\n", + l, + r, + get_supertype(l, r).unwrap() + )); + } + } + assert_eq!( + "supertype of UInt8 and UInt8 is UInt8 +\nsupertype of UInt8 and UInt16 is UInt16 +\nsupertype of UInt8 and UInt32 is UInt32 +\nsupertype of UInt8 and UInt64 is UInt64 +\nsupertype of UInt8 and Int8 is Int8 +\nsupertype of UInt8 and Int16 is Int16 +\nsupertype of UInt8 and Int32 is Int32 +\nsupertype of UInt8 and Int64 is Int64 +\nsupertype of UInt8 and Float32 is Float32 +\nsupertype of UInt8 and Float64 is Float64 +\nsupertype of UInt16 and UInt8 is UInt16 +\nsupertype of UInt16 and UInt16 is UInt16 +\nsupertype of UInt16 and UInt32 is UInt32 +\nsupertype of UInt16 and UInt64 is UInt64 +\nsupertype of UInt16 and Int8 is Int16 +\nsupertype of UInt16 and Int16 is Int16 +\nsupertype of UInt16 and Int32 is Int32 +\nsupertype of UInt16 and Int64 is Int64 +\nsupertype of UInt16 and Float32 is Float32 +\nsupertype of UInt16 and Float64 is Float64 +\nsupertype of UInt32 and UInt8 is UInt32 +\nsupertype of UInt32 and UInt16 is UInt32 +\nsupertype of UInt32 and UInt32 is UInt32 +\nsupertype of UInt32 and UInt64 is UInt64 +\nsupertype of UInt32 and Int8 is Int32 +\nsupertype of UInt32 and Int16 is Int32 +\nsupertype of UInt32 and Int32 is Int32 +\nsupertype of UInt32 and Int64 is Int64 +\nsupertype of UInt32 and Float32 is Float32 +\nsupertype of UInt32 and Float64 is Float64 +\nsupertype of UInt64 and UInt8 is UInt64 +\nsupertype of UInt64 and UInt16 is UInt64 +\nsupertype of UInt64 and UInt32 is UInt64 +\nsupertype of UInt64 and UInt64 is UInt64 +\nsupertype of UInt64 and Int8 is Int64 +\nsupertype of UInt64 and Int16 is Int64 +\nsupertype of UInt64 and Int32 is Int64 +\nsupertype of UInt64 and Int64 is Int64 +\nsupertype of UInt64 and Float32 is Float32 +\nsupertype of UInt64 and Float64 is Float64 +\nsupertype of Int8 and UInt8 is Int8 +\nsupertype of Int8 and UInt16 is Int16 +\nsupertype of Int8 and UInt32 is Int32 +\nsupertype of Int8 and UInt64 is Int64 +\nsupertype of Int8 and Int8 is Int8 +\nsupertype of Int8 and Int16 is Int16 +\nsupertype of Int8 and Int32 is Int32 +\nsupertype of Int8 and Int64 is Int64 +\nsupertype of Int8 and Float32 is Float32 +\nsupertype of Int8 and Float64 is Float64 +\nsupertype of Int16 and UInt8 is Int16 +\nsupertype of Int16 and UInt16 is Int16 +\nsupertype of Int16 and UInt32 is Int32 +\nsupertype of Int16 and UInt64 is Int64 +\nsupertype of Int16 and Int8 is Int16 +\nsupertype of Int16 and Int16 is Int16 +\nsupertype of Int16 and Int32 is Int32 +\nsupertype of Int16 and Int64 is Int64 +\nsupertype of Int16 and Float32 is Float32 +\nsupertype of Int16 and Float64 is Float64 +\nsupertype of Int32 and UInt8 is Int32 +\nsupertype of Int32 and UInt16 is Int32 +\nsupertype of Int32 and UInt32 is Int32 +\nsupertype of Int32 and UInt64 is Int64 +\nsupertype of Int32 and Int8 is Int32 +\nsupertype of Int32 and Int16 is Int32 +\nsupertype of Int32 and Int32 is Int32 +\nsupertype of Int32 and Int64 is Int64 +\nsupertype of Int32 and Float32 is Float32 +\nsupertype of Int32 and Float64 is Float64 +\nsupertype of Int64 and UInt8 is Int64 +\nsupertype of Int64 and UInt16 is Int64 +\nsupertype of Int64 and UInt32 is Int64 +\nsupertype of Int64 and UInt64 is Int64 +\nsupertype of Int64 and Int8 is Int64 +\nsupertype of Int64 and Int16 is Int64 +\nsupertype of Int64 and Int32 is Int64 +\nsupertype of Int64 and Int64 is Int64 +\nsupertype of Int64 and Float32 is Float32 +\nsupertype of Int64 and Float64 is Float64 +\nsupertype of Float32 and UInt8 is Float32 +\nsupertype of Float32 and UInt16 is Float32 +\nsupertype of Float32 and UInt32 is Float32 +\nsupertype of Float32 and UInt64 is Float32 +\nsupertype of Float32 and Int8 is Float32 +\nsupertype of Float32 and Int16 is Float32 +\nsupertype of Float32 and Int32 is Float32 +\nsupertype of Float32 and Int64 is Float32 +\nsupertype of Float32 and Float32 is Float32 +\nsupertype of Float32 and Float64 is Float64 +\nsupertype of Float64 and UInt8 is Float64 +\nsupertype of Float64 and UInt16 is Float64 +\nsupertype of Float64 and UInt32 is Float64 +\nsupertype of Float64 and UInt64 is Float64 +\nsupertype of Float64 and Int8 is Float64 +\nsupertype of Float64 and Int16 is Float64 +\nsupertype of Float64 and Int32 is Float64 +\nsupertype of Float64 and Int64 is Float64 +\nsupertype of Float64 and Float32 is Float64 +\nsupertype of Float64 and Float64 is Float64 +\n", + result + ); + } + #[test] fn test_collect_expr() { let mut accum: HashSet = HashSet::new();