From 9ed2591d31f784ba1634a6d02bc3d8b3f288c1bf Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 8 Nov 2022 11:01:35 -0500 Subject: [PATCH 1/3] Add support for timestamp casts in unwrap_cast_in_comparison optimzier pass --- .../src/unwrap_cast_in_comparison.rs | 153 +++++++++++++++++- .../optimizer/tests/integration-test.rs | 3 +- 2 files changed, 154 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 7ac91ae3cbf9b..bcce98ecf7364 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -21,7 +21,7 @@ use crate::utils::rewrite_preserving_name; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ - DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, + DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast}; @@ -288,6 +288,7 @@ fn is_support_data_type(data_type: &DataType) -> bool { | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _) + | DataType::Timestamp(_, _) ) } @@ -306,6 +307,7 @@ fn try_cast_literal_to_type( } let mul = match target_type { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128, + DataType::Timestamp(_, _) => 1_i128, DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), other_type => { return Err(DataFusionError::Internal(format!( @@ -319,6 +321,7 @@ fn try_cast_literal_to_type( DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), DataType::Decimal128(precision, _) => ( // Different precision for decimal128 can store different range of value. // For example, the precision is 3, the max of value is `999` and the min @@ -338,6 +341,10 @@ fn try_cast_literal_to_type( ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), ScalarValue::Decimal128(Some(v), _, scale) => { let lit_scale_mul = 10_i128.pow(*scale as u32); if mul >= lit_scale_mul { @@ -376,6 +383,18 @@ fn try_cast_literal_to_type( DataType::Int16 => ScalarValue::Int16(Some(value as i16)), DataType::Int32 => ScalarValue::Int32(Some(value as i32)), DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::Timestamp(TimeUnit::Second, tz) => { + ScalarValue::TimestampSecond(Some(value as i64), tz.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + ScalarValue::TimestampMillisecond(Some(value as i64), tz.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + ScalarValue::TimestampMicrosecond(Some(value as i64), tz.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + ScalarValue::TimestampNanosecond(Some(value as i64), tz.clone()) + } DataType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) } @@ -629,6 +648,18 @@ mod tests { assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } + #[test] + /// Basic integration test for unwrapping casts with different timezones + fn test_unwrap_cast_with_timestamp_nanos() { + let schema = expr_test_schema(); + // cast(ts_nano as Timestamp(Nanosecond, UTC)) < 1666612093000000000::Timestamp(Nanosecond, Utc)) + let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type()) + .lt(lit_timestamp_nano_utc(1666612093000000000)); + let expected = + col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), @@ -646,6 +677,8 @@ mod tests { DFField::new(None, "c4", DataType::Decimal128(38, 37), false), DFField::new(None, "c5", DataType::Float32, false), DFField::new(None, "c6", DataType::UInt32, false), + DFField::new(None, "ts_nano_none", timestamp_nano_none_type(), false), + DFField::new(None, "ts_nano_utf", timestamp_nano_utc_type(), false), ], HashMap::new(), ) @@ -669,10 +702,29 @@ mod tests { lit(ScalarValue::Decimal128(Some(value), precision, scale)) } + fn lit_timestamp_nano_none(ts: i64) -> Expr { + lit(ScalarValue::TimestampNanosecond(Some(ts), None)) + } + + fn lit_timestamp_nano_utc(ts: i64) -> Expr { + let utc = Some("+0:00".to_string()); + lit(ScalarValue::TimestampNanosecond(Some(ts), utc)) + } + fn null_decimal(precision: u8, scale: u8) -> Expr { lit(ScalarValue::Decimal128(None, precision, scale)) } + fn timestamp_nano_none_type() -> DataType { + DataType::Timestamp(TimeUnit::Nanosecond, None) + } + + // this is the type that now() returns + fn timestamp_nano_utc_type() -> DataType { + let utc = Some("+0:00".to_string()); + DataType::Timestamp(TimeUnit::Nanosecond, utc) + } + #[test] fn test_try_cast_to_type_nulls() { // test values that can be cast to/from all integer types @@ -783,6 +835,105 @@ mod tests { ); } + #[test] + fn test_try_cast_to_type_timestamps() { + for time_unit in [ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + let utc = Some("+0:00".to_string()); + // No timezone, utc timezone + let (lit_tz_none, lit_tz_utc) = match time_unit { + TimeUnit::Second => ( + ScalarValue::TimestampSecond(Some(12345), None), + ScalarValue::TimestampSecond(Some(12345), utc), + ), + + TimeUnit::Millisecond => ( + ScalarValue::TimestampMillisecond(Some(12345), None), + ScalarValue::TimestampMillisecond(Some(12345), utc), + ), + + TimeUnit::Microsecond => ( + ScalarValue::TimestampMicrosecond(Some(12345), None), + ScalarValue::TimestampMicrosecond(Some(12345), utc), + ), + + TimeUnit::Nanosecond => ( + ScalarValue::TimestampNanosecond(Some(12345), None), + ScalarValue::TimestampNanosecond(Some(12345), utc), + ), + }; + + // Note that datafusion ignores timezones for comparisons + assert_eq!(lit_tz_none, lit_tz_utc); + + // e.g. DataType::Timestamp(_, None) + let dt_tz_none = lit_tz_none.get_datatype(); + + // e.g. DataType::Timestamp(_, Some(utc)) + let dt_tz_utc = lit_tz_utc.get_datatype(); + + // None <--> None + expect_cast( + lit_tz_none.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // None <--> Utc + expect_cast( + lit_tz_none.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // Utc <--> None + expect_cast( + lit_tz_utc.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // Utc <--> Utc + expect_cast( + lit_tz_utc.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to int64 + expect_cast( + lit_tz_utc.clone(), + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(12345))), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to string (not supported yet) + expect_cast( + lit_tz_utc.clone(), + DataType::LargeUtf8, + ExpectedCast::NoValue, + ); + } + } + #[test] fn test_try_cast_to_type_unsupported() { // int64 to list diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index be62ba2a579e6..48cd831bdc8ba 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -236,7 +236,8 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { // constant and compared to the column without a cast so it can be // pushed down / pruned let expected = - "Projection: test.col_int32\n Filter: CAST(test.col_ts_nano_none AS Timestamp(Nanosecond, Some(\"+00:00\"))) < TimestampNanosecond(1666612093000000000, Some(\"+00:00\"))\ + "Projection: test.col_int32\ + \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ \n TableScan: test projection=[col_int32, col_ts_nano_none]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) From 8b67ded63dfd0a0e7c33adc3b1129dbb50bdb39d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 9 Nov 2022 13:30:30 -0500 Subject: [PATCH 2/3] correct comment in test --- datafusion/optimizer/src/unwrap_cast_in_comparison.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index bcce98ecf7364..581ecdb43b5a7 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -727,7 +727,7 @@ mod tests { #[test] fn test_try_cast_to_type_nulls() { - // test values that can be cast to/from all integer types + // test that nulls can be cast to/from all integer types let scalars = vec![ ScalarValue::Int8(None), ScalarValue::Int16(None), From 75484ed9b5f1244ce046baa1efc66bf382a8ec8f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 9 Nov 2022 13:34:16 -0500 Subject: [PATCH 3/3] Update datafusion/optimizer/src/unwrap_cast_in_comparison.rs --- datafusion/optimizer/src/unwrap_cast_in_comparison.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 581ecdb43b5a7..28b0856848dac 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -867,7 +867,8 @@ mod tests { ), }; - // Note that datafusion ignores timezones for comparisons + // Datafusion ignores timezones for comparisons of ScalarValue + // so double check it here assert_eq!(lit_tz_none, lit_tz_utc); // e.g. DataType::Timestamp(_, None)