Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 231 additions & 107 deletions rust/datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,117 +98,117 @@ pub fn exprlist_to_fields(expr: &Vec<Expr>, input_schema: &Schema) -> Result<Vec

/// Given two datatypes, determine the supertype that both types can safely be cast to
pub fn get_supertype(l: &DataType, r: &DataType) -> Result<DataType> {
match _get_supertype(l, r) {
Some(dt) => Ok(dt),
None => match _get_supertype(r, l) {
Some(dt) => Ok(dt),
None => Err(ExecutionError::InternalError(format!(
"Failed to determine supertype of {:?} and {:?}",
l, r
))),
},
use arrow::datatypes::DataType::*;
let d = if l == r {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the old implementation the super type of UInt16 and Int8 is None but now it is Int16. Why is that? Also do you know how C++ handles this case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not statically compatible in C++ (safely). https://github.com/apache/arrow/blob/master/cpp/src/arrow/compute/kernels/cast.cc#L156-L208

It can be safe to cast, but it needs a runtime check.

Some(l.clone())
} else if l == &Utf8 || r == &Utf8 {
Some(Utf8)
} else if is_signed_int(l) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block is hard to follow, I'd rewrite this with small functions.

if is_signed_int(r) {
if bit_width(l) >= bit_width(r) {
Some(l.clone())
} else {
Some(r.clone())
}
} else if is_unsigned_int(r) {
match bit_width(l).max(bit_width(r)) {
8 => Some(Int8),
16 => Some(Int16),
32 => Some(Int32),
64 => Some(Int64),
_ => None,
}
} else if is_floating_point(r) {
Some(r.clone())
} else {
None
}
} else if is_unsigned_int(l) {
if is_signed_int(r) {
match bit_width(l).max(bit_width(r)) {
8 => Some(Int8),
16 => Some(Int16),
32 => Some(Int32),
64 => Some(Int64),
_ => None,
}
} else if is_unsigned_int(r) {
if bit_width(l) >= bit_width(r) {
Some(l.clone())
} else {
Some(r.clone())
}
} else if is_floating_point(r) {
Some(r.clone())
} else {
None
}
} else if is_floating_point(l) {
if is_int(r) {
Some(l.clone())
} else if is_floating_point(r) {
if bit_width(l) >= bit_width(r) {
Some(l.clone())
} else {
Some(r.clone())
}
} else {
None
}
} else {
None
};

match d {
Some(dd) => Ok(dd),
None => Err(ExecutionError::General(format!(
"Could not determine supertype for {:?} and {:?}",
l, r
))),
}
}

/// Given two datatypes, determine the supertype that both types can safely be cast to
fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
fn is_int(d: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match d {
Int8 | Int16 | Int32 | Int64 => true,
UInt8 | UInt16 | UInt32 | UInt64 => true,
_ => false,
}
}

fn is_signed_int(d: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match d {
Int8 | Int16 | Int32 | Int64 => true,
_ => false,
}
}

fn is_unsigned_int(d: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match (l, r) {
(UInt8, Int8) => Some(Int8),
(UInt8, Int16) => Some(Int16),
(UInt8, Int32) => Some(Int32),
(UInt8, Int64) => Some(Int64),

(UInt16, Int16) => Some(Int16),
(UInt16, Int32) => Some(Int32),
(UInt16, Int64) => Some(Int64),

(UInt32, Int32) => Some(Int32),
(UInt32, Int64) => Some(Int64),

(UInt64, Int64) => Some(Int64),

(Int8, UInt8) => Some(Int8),

(Int16, UInt8) => Some(Int16),
(Int16, UInt16) => Some(Int16),

(Int32, UInt8) => Some(Int32),
(Int32, UInt16) => Some(Int32),
(Int32, UInt32) => Some(Int32),

(Int64, UInt8) => Some(Int64),
(Int64, UInt16) => Some(Int64),
(Int64, UInt32) => Some(Int64),
(Int64, UInt64) => Some(Int64),

(UInt8, UInt8) => Some(UInt8),
(UInt8, UInt16) => Some(UInt16),
(UInt8, UInt32) => Some(UInt32),
(UInt8, UInt64) => Some(UInt64),
(UInt8, Float32) => Some(Float32),
(UInt8, Float64) => Some(Float64),

(UInt16, UInt8) => Some(UInt16),
(UInt16, UInt16) => Some(UInt16),
(UInt16, UInt32) => Some(UInt32),
(UInt16, UInt64) => Some(UInt64),
(UInt16, Float32) => Some(Float32),
(UInt16, Float64) => Some(Float64),

(UInt32, UInt8) => Some(UInt32),
(UInt32, UInt16) => Some(UInt32),
(UInt32, UInt32) => Some(UInt32),
(UInt32, UInt64) => Some(UInt64),
(UInt32, Float32) => Some(Float32),
(UInt32, Float64) => Some(Float64),

(UInt64, UInt8) => Some(UInt64),
(UInt64, UInt16) => Some(UInt64),
(UInt64, UInt32) => Some(UInt64),
(UInt64, UInt64) => Some(UInt64),
(UInt64, Float32) => Some(Float32),
(UInt64, Float64) => Some(Float64),

(Int8, Int8) => Some(Int8),
(Int8, Int16) => Some(Int16),
(Int8, Int32) => Some(Int32),
(Int8, Int64) => Some(Int64),
(Int8, Float32) => Some(Float32),
(Int8, Float64) => Some(Float64),

(Int16, Int8) => Some(Int16),
(Int16, Int16) => Some(Int16),
(Int16, Int32) => Some(Int32),
(Int16, Int64) => Some(Int64),
(Int16, Float32) => Some(Float32),
(Int16, Float64) => Some(Float64),

(Int32, Int8) => Some(Int32),
(Int32, Int16) => Some(Int32),
(Int32, Int32) => Some(Int32),
(Int32, Int64) => Some(Int64),
(Int32, Float32) => Some(Float32),
(Int32, Float64) => Some(Float64),

(Int64, Int8) => Some(Int64),
(Int64, Int16) => Some(Int64),
(Int64, Int32) => Some(Int64),
(Int64, Int64) => Some(Int64),
(Int64, Float32) => Some(Float32),
(Int64, Float64) => Some(Float64),

(Float32, Float32) => Some(Float32),
(Float32, Float64) => Some(Float64),
(Float64, Float32) => Some(Float64),
(Float64, Float64) => Some(Float64),

(Utf8, _) => Some(Utf8),
(_, Utf8) => Some(Utf8),

(Boolean, Boolean) => Some(Boolean),

_ => None,
match d {
UInt8 | UInt16 | UInt32 | UInt64 => true,
_ => false,
}
}

fn is_floating_point(d: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match d {
Float32 | Float64 => true,
_ => false,
}
}

fn bit_width(d: &DataType) -> usize {
use arrow::datatypes::DataType::*;
match d {
Int8 | UInt8 => 8,
Int16 | UInt16 => 16,
Int32 | UInt32 | Float32 => 32,
Int64 | UInt64 | Float64 => 64,
_ => 0,
}
}

Expand All @@ -217,9 +217,133 @@ mod tests {
use super::*;
use crate::logicalplan::Expr;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::*;
use std::collections::HashSet;
use std::sync::Arc;

#[test]
fn test_supertype_numeric_types() {
let types = vec![
UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64,
];
let mut result = String::new();
for l in &types {
for r in &types {
result.push_str(&format!(
"supertype of {:?} and {:?} is {:?}
\n",
l,
r,
get_supertype(l, r).unwrap()
));
}
}
assert_eq!(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid comparing with string? whenever there's an error it is very hard to find out since you are comparing two huge strings.

"supertype of UInt8 and UInt8 is UInt8
\nsupertype of UInt8 and UInt16 is UInt16
\nsupertype of UInt8 and UInt32 is UInt32
\nsupertype of UInt8 and UInt64 is UInt64
\nsupertype of UInt8 and Int8 is Int8
\nsupertype of UInt8 and Int16 is Int16
\nsupertype of UInt8 and Int32 is Int32
\nsupertype of UInt8 and Int64 is Int64
\nsupertype of UInt8 and Float32 is Float32
\nsupertype of UInt8 and Float64 is Float64
\nsupertype of UInt16 and UInt8 is UInt16
\nsupertype of UInt16 and UInt16 is UInt16
\nsupertype of UInt16 and UInt32 is UInt32
\nsupertype of UInt16 and UInt64 is UInt64
\nsupertype of UInt16 and Int8 is Int16
\nsupertype of UInt16 and Int16 is Int16
\nsupertype of UInt16 and Int32 is Int32
\nsupertype of UInt16 and Int64 is Int64
\nsupertype of UInt16 and Float32 is Float32
\nsupertype of UInt16 and Float64 is Float64
\nsupertype of UInt32 and UInt8 is UInt32
\nsupertype of UInt32 and UInt16 is UInt32
\nsupertype of UInt32 and UInt32 is UInt32
\nsupertype of UInt32 and UInt64 is UInt64
\nsupertype of UInt32 and Int8 is Int32
\nsupertype of UInt32 and Int16 is Int32
\nsupertype of UInt32 and Int32 is Int32
\nsupertype of UInt32 and Int64 is Int64
\nsupertype of UInt32 and Float32 is Float32
\nsupertype of UInt32 and Float64 is Float64
\nsupertype of UInt64 and UInt8 is UInt64
\nsupertype of UInt64 and UInt16 is UInt64
\nsupertype of UInt64 and UInt32 is UInt64
\nsupertype of UInt64 and UInt64 is UInt64
\nsupertype of UInt64 and Int8 is Int64
\nsupertype of UInt64 and Int16 is Int64
\nsupertype of UInt64 and Int32 is Int64
\nsupertype of UInt64 and Int64 is Int64
\nsupertype of UInt64 and Float32 is Float32
\nsupertype of UInt64 and Float64 is Float64
\nsupertype of Int8 and UInt8 is Int8
\nsupertype of Int8 and UInt16 is Int16
\nsupertype of Int8 and UInt32 is Int32
\nsupertype of Int8 and UInt64 is Int64
\nsupertype of Int8 and Int8 is Int8
\nsupertype of Int8 and Int16 is Int16
\nsupertype of Int8 and Int32 is Int32
\nsupertype of Int8 and Int64 is Int64
\nsupertype of Int8 and Float32 is Float32
\nsupertype of Int8 and Float64 is Float64
\nsupertype of Int16 and UInt8 is Int16
\nsupertype of Int16 and UInt16 is Int16
\nsupertype of Int16 and UInt32 is Int32
\nsupertype of Int16 and UInt64 is Int64
\nsupertype of Int16 and Int8 is Int16
\nsupertype of Int16 and Int16 is Int16
\nsupertype of Int16 and Int32 is Int32
\nsupertype of Int16 and Int64 is Int64
\nsupertype of Int16 and Float32 is Float32
\nsupertype of Int16 and Float64 is Float64
\nsupertype of Int32 and UInt8 is Int32
\nsupertype of Int32 and UInt16 is Int32
\nsupertype of Int32 and UInt32 is Int32
\nsupertype of Int32 and UInt64 is Int64
\nsupertype of Int32 and Int8 is Int32
\nsupertype of Int32 and Int16 is Int32
\nsupertype of Int32 and Int32 is Int32
\nsupertype of Int32 and Int64 is Int64
\nsupertype of Int32 and Float32 is Float32
\nsupertype of Int32 and Float64 is Float64
\nsupertype of Int64 and UInt8 is Int64
\nsupertype of Int64 and UInt16 is Int64
\nsupertype of Int64 and UInt32 is Int64
\nsupertype of Int64 and UInt64 is Int64
\nsupertype of Int64 and Int8 is Int64
\nsupertype of Int64 and Int16 is Int64
\nsupertype of Int64 and Int32 is Int64
\nsupertype of Int64 and Int64 is Int64
\nsupertype of Int64 and Float32 is Float32
\nsupertype of Int64 and Float64 is Float64
\nsupertype of Float32 and UInt8 is Float32
\nsupertype of Float32 and UInt16 is Float32
\nsupertype of Float32 and UInt32 is Float32
\nsupertype of Float32 and UInt64 is Float32
\nsupertype of Float32 and Int8 is Float32
\nsupertype of Float32 and Int16 is Float32
\nsupertype of Float32 and Int32 is Float32
\nsupertype of Float32 and Int64 is Float32
\nsupertype of Float32 and Float32 is Float32
\nsupertype of Float32 and Float64 is Float64
\nsupertype of Float64 and UInt8 is Float64
\nsupertype of Float64 and UInt16 is Float64
\nsupertype of Float64 and UInt32 is Float64
\nsupertype of Float64 and UInt64 is Float64
\nsupertype of Float64 and Int8 is Float64
\nsupertype of Float64 and Int16 is Float64
\nsupertype of Float64 and Int32 is Float64
\nsupertype of Float64 and Int64 is Float64
\nsupertype of Float64 and Float32 is Float64
\nsupertype of Float64 and Float64 is Float64
\n",
result
);
}

#[test]
fn test_collect_expr() {
let mut accum: HashSet<usize> = HashSet::new();
Expand Down