diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index 22e3b0358ce55..f5dbd966132be 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -1649,3 +1649,68 @@ async fn binary_mathematical_operator_with_null_lt() {
         assert!(batch.columns()[0].is_null(1));
     }
 }
+
+#[tokio::test]
+async fn query_binary_eq() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Binary, true),
+        Field::new("c2", DataType::LargeBinary, true),
+        Field::new("c3", DataType::Binary, true),
+        Field::new("c4", DataType::LargeBinary, true),
+    ]));
+
+    let c1 = BinaryArray::from_opt_vec(vec![
+        Some(b"one"),
+        Some(b"two"),
+        None,
+        Some(b""),
+        Some(b"three"),
+    ]);
+    let c2 = LargeBinaryArray::from_opt_vec(vec![
+        Some(b"one"),
+        Some(b"two"),
+        None,
+        Some(b""),
+        Some(b"three"),
+    ]);
+    let c3 = BinaryArray::from_opt_vec(vec![
+        Some(b"one"),
+        Some(b""),
+        None,
+        Some(b""),
+        Some(b"three"),
+    ]);
+    let c4 = LargeBinaryArray::from_opt_vec(vec![
+        Some(b"one"),
+        Some(b"two"),
+        None,
+        Some(b""),
+        Some(b""),
+    ]);
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
+    )?;
+
+    let table = MemTable::try_new(schema, vec![vec![data]])?;
+
+    let ctx = SessionContext::new();
+
+    ctx.register_table("test", Arc::new(table))?;
+
+    let sql = "
+        SELECT sha256(c1)=digest('one', 'sha256'), sha256(c2)=sha256('two'), digest(c1, 'blake2b')=digest(c3, 'blake2b'), c2=c4
+        FROM test
+    ";
+    let actual = execute(&ctx, sql).await;
+    let expected = vec![
+        vec!["true", "false", "true", "true"],
+        vec!["false", "true", "false", "true"],
+        vec!["NULL", "NULL", "NULL", "NULL"],
+        vec!["false", "false", "true", "true"],
+        vec!["false", "false", "true", "false"],
+    ];
+    assert_eq!(expected, actual);
+    Ok(())
+}
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index adae10e224cd7..8b0f1646608dc 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -334,7 +334,7 @@ scalar_expr!(BitLength, bit_length, string);
 scalar_expr!(CharacterLength, character_length, string);
 scalar_expr!(CharacterLength, length, string);
 scalar_expr!(Chr, chr, string);
-scalar_expr!(Digest, digest, string, algorithm);
+scalar_expr!(Digest, digest, input, algorithm);
 scalar_expr!(InitCap, initcap, string);
 scalar_expr!(Left, left, string, count);
 scalar_expr!(Lower, lower, string);
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 7851edb137122..b8b17bf4003ea 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -73,7 +73,23 @@ macro_rules! make_utf8_to_return_type {
 
 make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
 make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
-make_utf8_to_return_type!(utf8_to_binary_type, DataType::Binary, DataType::Binary);
+
+fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result<DataType> {
+    Ok(match arg_type {
+        DataType::LargeUtf8
+        | DataType::Utf8
+        | DataType::Binary
+        | DataType::LargeBinary => DataType::Binary,
+        DataType::Null => DataType::Null,
+        _ => {
+            // this error is internal as `data_types` should have captured this.
+            return Err(DataFusionError::Internal(format!(
+                "The {:?} function can only accept strings or binary arrays.",
+                name
+            )));
+        }
+    })
+}
 
 /// Returns the datatype of the scalar function
 pub fn return_type(
@@ -154,19 +170,19 @@ pub fn return_type(
         BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"),
         BuiltinScalarFunction::Rtrim => utf8_to_str_type(&input_expr_types[0], "rtrimp"),
         BuiltinScalarFunction::SHA224 => {
-            utf8_to_binary_type(&input_expr_types[0], "sha224")
+            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224")
         }
         BuiltinScalarFunction::SHA256 => {
-            utf8_to_binary_type(&input_expr_types[0], "sha256")
+            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256")
         }
         BuiltinScalarFunction::SHA384 => {
-            utf8_to_binary_type(&input_expr_types[0], "sha384")
+            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384")
         }
         BuiltinScalarFunction::SHA512 => {
-            utf8_to_binary_type(&input_expr_types[0], "sha512")
+            utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512")
         }
         BuiltinScalarFunction::Digest => {
-            utf8_to_binary_type(&input_expr_types[0], "digest")
+            utf8_or_binary_to_binary_type(&input_expr_types[0], "digest")
         }
         BuiltinScalarFunction::SplitPart => {
             utf8_to_str_type(&input_expr_types[0], "split_part")
@@ -284,18 +300,27 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
             conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(),
             fun.volatility(),
         ),
+        BuiltinScalarFunction::SHA224
+        | BuiltinScalarFunction::SHA256
+        | BuiltinScalarFunction::SHA384
+        | BuiltinScalarFunction::SHA512
+        | BuiltinScalarFunction::MD5 => Signature::uniform(
+            1,
+            vec![
+                DataType::Utf8,
+                DataType::LargeUtf8,
+                DataType::Binary,
+                DataType::LargeBinary,
+            ],
+            fun.volatility(),
+        ),
         BuiltinScalarFunction::Ascii
         | BuiltinScalarFunction::BitLength
         | BuiltinScalarFunction::CharacterLength
         | BuiltinScalarFunction::InitCap
         | BuiltinScalarFunction::Lower
-        | BuiltinScalarFunction::MD5
         | BuiltinScalarFunction::OctetLength
         | BuiltinScalarFunction::Reverse
-        | BuiltinScalarFunction::SHA224
-        | BuiltinScalarFunction::SHA256
-        | BuiltinScalarFunction::SHA384
-        | BuiltinScalarFunction::SHA512
         | BuiltinScalarFunction::Upper => Signature::uniform(
             1,
             vec![DataType::Utf8, DataType::LargeUtf8],
@@ -401,9 +426,15 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
         BuiltinScalarFunction::FromUnixtime => {
             Signature::uniform(1, vec![DataType::Int64], fun.volatility())
         }
-        BuiltinScalarFunction::Digest => {
-            Signature::exact(vec![DataType::Utf8, DataType::Utf8], fun.volatility())
-        }
+        BuiltinScalarFunction::Digest => Signature::one_of(
+            vec![
+                TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
+                TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
+                TypeSignature::Exact(vec![DataType::Binary, DataType::Utf8]),
+                TypeSignature::Exact(vec![DataType::LargeBinary, DataType::Utf8]),
+            ],
+            fun.volatility(),
+        ),
         BuiltinScalarFunction::DateTrunc => Signature::exact(
             vec![
                 DataType::Utf8,
diff --git a/datafusion/physical-expr/src/crypto_expressions.rs b/datafusion/physical-expr/src/crypto_expressions.rs
index e0314317c25f1..85f3ebdb5cada 100644
--- a/datafusion/physical-expr/src/crypto_expressions.rs
+++ b/datafusion/physical-expr/src/crypto_expressions.rs
@@ -19,7 +19,8 @@
 
 use arrow::{
     array::{
-        Array, ArrayRef, BinaryArray, GenericStringArray, OffsetSizeTrait, StringArray,
+        Array, ArrayRef, BinaryArray, GenericBinaryArray, GenericStringArray,
+        OffsetSizeTrait, StringArray,
     },
     datatypes::DataType,
 };
@@ -59,17 +60,22 @@ fn digest_process(
 ) -> Result<ColumnarValue> {
     match value {
         ColumnarValue::Array(a) => match a.data_type() {
-            DataType::Utf8 => digest_algorithm.digest_array::<i32>(a.as_ref()),
-            DataType::LargeUtf8 => digest_algorithm.digest_array::<i64>(a.as_ref()),
+            DataType::Utf8 => digest_algorithm.digest_utf8_array::<i32>(a.as_ref()),
+            DataType::LargeUtf8 => digest_algorithm.digest_utf8_array::<i64>(a.as_ref()),
+            DataType::Binary => digest_algorithm.digest_binary_array::<i32>(a.as_ref()),
+            DataType::LargeBinary => {
+                digest_algorithm.digest_binary_array::<i64>(a.as_ref())
+            }
             other => Err(DataFusionError::Internal(format!(
                 "Unsupported data type {:?} for function {}",
                 other, digest_algorithm,
             ))),
         },
         ColumnarValue::Scalar(scalar) => match scalar {
-            ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => {
-                Ok(digest_algorithm.digest_scalar(a))
-            }
+            ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(digest_algorithm
+                .digest_scalar(&a.as_ref().map(|s: &String| s.as_bytes()))),
+            ScalarValue::Binary(a) | ScalarValue::LargeBinary(a) => Ok(digest_algorithm
+                .digest_scalar(&a.as_ref().map(|v: &Vec<u8>| v.as_slice()))),
             other => Err(DataFusionError::Internal(format!(
                 "Unsupported data type {:?} for function {}",
                 other, digest_algorithm,
@@ -106,7 +112,7 @@ macro_rules! digest_to_scalar {
 
 impl DigestAlgorithm {
     /// digest an optional string to its hash value, null values are returned as is
-    fn digest_scalar(self, value: &Option<String>) -> ColumnarValue {
+    fn digest_scalar(self, value: &Option<&[u8]>) -> ColumnarValue {
         ColumnarValue::Scalar(match self {
             Self::Md5 => digest_to_scalar!(Md5, value),
             Self::Sha224 => digest_to_scalar!(Sha224, value),
@@ -115,16 +121,55 @@ impl DigestAlgorithm {
             Self::Sha512 => digest_to_scalar!(Sha512, value),
             Self::Blake2b => digest_to_scalar!(Blake2b512, value),
             Self::Blake2s => digest_to_scalar!(Blake2s256, value),
-            Self::Blake3 => ScalarValue::Binary(value.as_ref().map(|v| {
+            Self::Blake3 => ScalarValue::Binary(value.map(|v| {
                 let mut digest = Blake3::default();
-                digest.update(v.as_bytes());
+                digest.update(v);
                 digest.finalize().as_bytes().to_vec()
             })),
         })
     }
 
+    /// digest a binary array to their hash values
+    fn digest_binary_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
+    where
+        T: OffsetSizeTrait,
+    {
+        let input_value = value
+            .as_any()
+            .downcast_ref::<GenericBinaryArray<T>>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(format!(
+                    "could not cast value to {}",
+                    type_name::<GenericBinaryArray<T>>()
+                ))
+            })?;
+        let array: ArrayRef = match self {
+            Self::Md5 => digest_to_array!(Md5, input_value),
+            Self::Sha224 => digest_to_array!(Sha224, input_value),
+            Self::Sha256 => digest_to_array!(Sha256, input_value),
+            Self::Sha384 => digest_to_array!(Sha384, input_value),
+            Self::Sha512 => digest_to_array!(Sha512, input_value),
+            Self::Blake2b => digest_to_array!(Blake2b512, input_value),
+            Self::Blake2s => digest_to_array!(Blake2s256, input_value),
+            Self::Blake3 => {
+                let binary_array: BinaryArray = input_value
+                    .iter()
+                    .map(|opt| {
+                        opt.map(|x| {
+                            let mut digest = Blake3::default();
+                            digest.update(x);
+                            digest.finalize().as_bytes().to_vec()
+                        })
+                    })
+                    .collect();
+                Arc::new(binary_array)
+            }
+        };
+        Ok(ColumnarValue::Array(array))
+    }
+
     /// digest a string array to their hash values
-    fn digest_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
+    fn digest_utf8_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
     where
         T: OffsetSizeTrait,
     {
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 6769032bff6de..8323aa5abbe69 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -28,6 +28,10 @@ use arrow::compute::kernels::arithmetic::{
     multiply_scalar, subtract, subtract_scalar,
 };
 use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
+use arrow::compute::kernels::comparison::{
+    eq_dyn_binary_scalar, gt_dyn_binary_scalar, gt_eq_dyn_binary_scalar,
+    lt_dyn_binary_scalar, lt_eq_dyn_binary_scalar, neq_dyn_binary_scalar,
+};
 use arrow::compute::kernels::comparison::{
     eq_dyn_bool_scalar, gt_dyn_bool_scalar, gt_eq_dyn_bool_scalar, lt_dyn_bool_scalar,
     lt_eq_dyn_bool_scalar, neq_dyn_bool_scalar,
@@ -201,6 +205,21 @@ macro_rules! compute_utf8_op_dyn_scalar {
     }};
 }
 
+/// Invoke a compute kernel on a data array and a scalar value
+macro_rules! compute_binary_op_dyn_scalar {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
+        if let Some(binary_value) = $RIGHT {
+            Ok(Arc::new(paste::expr! {[<$OP _dyn_binary_scalar>]}(
+                $LEFT,
+                &binary_value,
+            )?))
+        } else {
+            // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+            Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
+        }
+    }};
+}
+
 /// Invoke a compute kernel on a boolean data array and a scalar value
 macro_rules! compute_bool_op_dyn_scalar {
     ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
@@ -625,6 +644,8 @@ macro_rules! binary_array_op_dyn_scalar {
             ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, right, $OP, Decimal128Array),
             ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
             ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
+            ScalarValue::Binary(v) => compute_binary_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
+            ScalarValue::LargeBinary(v) => compute_binary_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
             ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
             ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
             ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),