From 0a3bf55c9b5a5b3cae2984309f9a3f9136f4f77c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 8 Nov 2022 11:01:35 -0500 Subject: [PATCH] Add additional testing for `unwrap_cast_in_comparison` --- .../src/unwrap_cast_in_comparison.rs | 196 +++++++++++++++++- 1 file changed, 195 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 3dfbaa028187a..c694b1e42e5c6 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -379,8 +379,10 @@ fn try_cast_literal_to_type( #[cfg(test)] mod tests { + use super::*; use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter; - use arrow::datatypes::DataType; + use arrow::compute::{cast_with_options, CastOptions}; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; use datafusion_expr::{cast, col, in_list, lit, try_cast, Expr}; @@ -653,4 +655,196 @@ mod tests { fn null_decimal(precision: u8, scale: u8) -> Expr { lit(ScalarValue::Decimal128(None, precision, scale)) } + + #[test] + fn test_try_cast_to_type_nulls() { + // test values that can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::Decimal128(None, 3, 0), + ScalarValue::Decimal128(None, 8, 2), + ]; + + for s1 in &scalars { + for s2 in &scalars { + expect_cast( + s1.clone(), + s2.get_datatype(), + ExpectedCast::Value(s2.clone()), + ); + } + } + } + + #[test] + fn test_try_cast_to_type_int_in_range() { + // test values that can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(Some(123)), + ScalarValue::Int16(Some(123)), + ScalarValue::Int32(Some(123)), + ScalarValue::Int64(Some(123)), + ScalarValue::Decimal128(Some(123), 3, 0), + ScalarValue::Decimal128(Some(12300), 8, 2), + ]; + + for s1 in &scalars { + for s2 in &scalars { + expect_cast( + s1.clone(), + s2.get_datatype(), + ExpectedCast::Value(s2.clone()), + ); + } + } + } + + #[test] + fn test_try_cast_to_type_int_out_of_range() { + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); + expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); + + expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); + + expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue); + + expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); + + // decimal out of range + expect_cast( + ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), + DataType::Int64, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1), + DataType::Int64, + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_decimal_cast_in_range() { + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 5), + ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)), + ); + } + + #[test] + fn test_try_decimal_cast_out_of_range() { + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::NoValue, + ); + + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(2, 0), + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_cast_to_type_unsupported() { + // int64 to list + expect_cast( + ScalarValue::Int64(Some(12345)), + DataType::List(Box::new(Field::new("f", DataType::Int32, true))), + ExpectedCast::NoValue, + ); + } + + #[derive(Debug, Clone)] + enum ExpectedCast { + /// test successfully cast value and it is as specified + Value(ScalarValue), + /// test returned OK, but could not cast the value + NoValue, + } + + /// Runs try_cast_literal_to_type with the specified inputs and + /// ensure it computes the expected output, and ensures the + /// casting is consistent with the Arrow kernels + fn expect_cast( + literal: ScalarValue, + target_type: DataType, + expected_result: ExpectedCast, + ) { + let actual_result = try_cast_literal_to_type(&literal, &target_type); + + println!("expect_cast: "); + println!(" {:?} --> {:?}", literal, target_type); + println!(" expected_result: {:?}", expected_result); + println!(" actual_result: {:?}", actual_result); + + match expected_result { + ExpectedCast::Value(expected_value) => { + let actual_value = actual_result + .expect("Expected success but got error") + .expect("Expected cast value but got None"); + + assert_eq!(actual_value, expected_value); + + // Verify that calling the arrow + // cast kernel yields the same results + // input array + let literal_array = literal.to_array_of_size(1); + let expected_array = expected_value.to_array_of_size(1); + let cast_array = cast_with_options( + &literal_array, + &target_type, + &CastOptions { safe: true }, + ) + .expect("Expected to be cast array with arrow cast kernel"); + + assert_eq!( + &expected_array, &cast_array, + "Result of casing {:?} with arrow was\n {:#?}\nbut expected\n{:#?}", + literal, cast_array, expected_array + ); + + // Verify that for timestamp types the timezones are the same + // (ScalarValue::cmp doesn't account for timezones); + if let ( + DataType::Timestamp(left_unit, left_tz), + DataType::Timestamp(right_unit, right_tz), + ) = (actual_value.get_datatype(), expected_value.get_datatype()) + { + assert_eq!(left_unit, right_unit); + assert_eq!(left_tz, right_tz); + } + } + ExpectedCast::NoValue => { + let actual_value = actual_result.expect("Expected success but got error"); + + assert!( + actual_value.is_none(), + "Expected no cast value, but got {:?}", + actual_value + ); + } + } + } }