diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index db686deb70709..0898be62c77f3 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -811,3 +811,29 @@ async fn sql_abs_decimal() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn decimal_null_scalar_array_comparison() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "select a < null from (values (1.1::decimal)) as t(a)"; + let actual = execute_to_batches(&ctx, sql).await; + assert_eq!(1, actual.len()); + assert_eq!(1, actual[0].num_columns()); + assert_eq!(1, actual[0].num_rows()); + assert!(actual[0].column(0).is_null(0)); + assert_eq!(&DataType::Boolean, actual[0].column(0).data_type()); + Ok(()) +} + +#[tokio::test] +async fn decimal_null_array_scalar_comparison() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "select null <= a from (values (1.1::decimal)) as t(a);"; + let actual = execute_to_batches(&ctx, sql).await; + assert_eq!(1, actual.len()); + assert_eq!(1, actual[0].num_columns()); + assert_eq!(1, actual[0].num_rows()); + assert!(actual[0].column(0).is_null(0)); + assert_eq!(&DataType::Boolean, actual[0].column(0).data_type()); + Ok(()) +} diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index b078da5f0b342..b3704dc70307a 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -120,13 +120,33 @@ impl std::fmt::Display for BinaryExpr { } } +macro_rules! compute_decimal_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ + let ll = $LEFT.as_any().downcast_ref::().unwrap(); + if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( + ll, + $RIGHT.try_into()?, + )?)) + } else { + // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE type + Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) + } + }}; +} + macro_rules! compute_decimal_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap(); - Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( - ll, - $RIGHT.try_into()?, - )?)) + let ll = $LEFT.as_any().downcast_ref::().unwrap(); + if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( + ll, + $RIGHT.try_into()?, + )?)) + } else { + // when the $RIGHT is a NULL, generate a NULL array of LEFT's datatype + Ok(Arc::new(new_null_array($LEFT.data_type(), $LEFT.len()))) + } }}; } @@ -642,7 +662,7 @@ macro_rules! binary_array_op_dyn_scalar { let result: Result> = match right { ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP, $OP_TYPE), - ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, right, $OP, Decimal128Array), + ScalarValue::Decimal128(..) => compute_decimal_op_dyn_scalar!($LEFT, right, $OP, $OP_TYPE), ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), ScalarValue::Binary(v) => compute_binary_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),