From 3b465c33e9e9655013cfc8529294e28f33f1d602 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Thu, 15 Jan 2026 23:05:19 +0800 Subject: [PATCH 1/4] optimize spark_hex dictionary path --- datafusion/spark/src/function/math/hex.rs | 95 +++++++++++-------- .../test_files/spark/math/hex.slt | 15 +++ 2 files changed, 71 insertions(+), 39 deletions(-) diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index 134324f45f5b..77f2bce1377b 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -19,7 +19,9 @@ use std::any::Any; use std::str::from_utf8_unchecked; use std::sync::Arc; -use arrow::array::{Array, BinaryArray, Int64Array, StringArray, StringBuilder}; +use arrow::array::{ + Array, ArrayRef, StringBuilder, +}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, @@ -27,12 +29,12 @@ use arrow::{ }; use datafusion_common::cast::as_large_binary_array; use datafusion_common::cast::as_string_view_array; -use datafusion_common::types::{NativeType, logical_int64, logical_string}; +use datafusion_common::types::{logical_int64, logical_string, NativeType}; use datafusion_common::utils::take_function_args; use datafusion_common::{ - DataFusionError, cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, exec_err, + DataFusionError, }; use datafusion_expr::{ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, @@ -92,11 +94,13 @@ impl ScalarUDFImpl for SparkHex { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { - Ok(DataType::Utf8) + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match &arg_types[0] { + DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) + } + _ => DataType::Utf8, + }) } fn invoke_with_args( @@ -241,29 +245,38 @@ pub fn compute_hex( let array = as_fixed_size_binary_array(array)?; hex_encode_bytes(array.iter(), lowercase, array.len()) } - DataType::Dictionary(_, value_type) => { + DataType::Dictionary(_, _) => { let dict = as_dictionary_array::(&array); + let dict_values = dict.values(); - match **value_type { + let encoded_values: ColumnarValue = match dict_values.data_type() { DataType::Int64 => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_int64(arr.into_iter(), dict.len()) + let arr = as_int64_array(dict_values)?; + hex_encode_int64(arr.iter(), arr.len())? } DataType::Utf8 => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + let arr = as_string_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? } DataType::Binary => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + let arr = as_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? } _ => { - exec_err!( + return exec_err!( "hex got an unexpected argument type: {}", - array.data_type() + dict_values.data_type() ) } - } + }; + + let encoded_values_array: ArrayRef = match encoded_values { + ColumnarValue::Array(a) => a, + ColumnarValue::Scalar(s) => Arc::new(s.to_array()?), + }; + + let new_dict = dict.with_values(encoded_values_array); + Ok(ColumnarValue::Array(Arc::new(new_dict))) } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), }, @@ -279,11 +292,12 @@ mod test { use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; use arrow::{ array::{ - BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, - StringDictionaryBuilder, as_string_array, + as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, + StringDictionaryBuilder, }, datatypes::{Int32Type, Int64Type}, }; + use datafusion_common::cast::as_dictionary_array; use datafusion_expr::ColumnarValue; #[test] @@ -295,12 +309,12 @@ mod test { input_builder.append_value("rust"); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("6869"); - string_builder.append_value("627965"); - string_builder.append_null(); - string_builder.append_value("72757374"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("6869"); + expected_builder.append_value("627965"); + expected_builder.append_null(); + expected_builder.append_value("72757374"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -310,7 +324,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -324,12 +338,12 @@ mod test { input_builder.append_value(3); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("1"); - string_builder.append_value("2"); - string_builder.append_null(); - string_builder.append_value("3"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("1"); + expected_builder.append_value("2"); + expected_builder.append_null(); + expected_builder.append_value("3"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -339,7 +353,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -353,7 +367,7 @@ mod test { input_builder.append_value("3"); let input = input_builder.finish(); - let mut expected_builder = StringBuilder::new(); + let mut expected_builder = StringDictionaryBuilder::::new(); expected_builder.append_value("31"); expected_builder.append_value("6A"); expected_builder.append_null(); @@ -368,7 +382,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -425,8 +439,11 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); - let expected = StringArray::from(vec![Some("20"), None, None]); + let result = as_dictionary_array(&result).unwrap(); + + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = StringArray::from(vec![Some("20"), None]); + let expected = DictionaryArray::new(keys, Arc::new(vals)); assert_eq!(&expected, result); } diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt index 05c9fb3f31b2..756088a26d6c 100644 --- a/datafusion/sqllogictest/test_files/spark/math/hex.slt +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -63,3 +63,18 @@ query T SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b; ---- 74657374 + +statement ok +CREATE TABLE t_dict_utf8 AS +SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') as dict_col +FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar'); + +query T +SELECT hex(dict_col) FROM t_dict_utf8; +---- +666F6F +626172 +666F6F +NULL +62617A +626172 From 8b5087ebd9171c9f71622007c3e29e3564f54516 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Thu, 15 Jan 2026 23:27:13 +0800 Subject: [PATCH 2/4] cargo fmt --- datafusion/spark/src/function/math/hex.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index 77f2bce1377b..5d88d8684d49 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -19,9 +19,7 @@ use std::any::Any; use std::str::from_utf8_unchecked; use std::sync::Arc; -use arrow::array::{ - Array, ArrayRef, StringBuilder, -}; +use arrow::array::{Array, ArrayRef, StringBuilder}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, @@ -29,12 +27,12 @@ use arrow::{ }; use datafusion_common::cast::as_large_binary_array; use datafusion_common::cast::as_string_view_array; -use datafusion_common::types::{logical_int64, logical_string, NativeType}; +use datafusion_common::types::{NativeType, logical_int64, logical_string}; use datafusion_common::utils::take_function_args; use datafusion_common::{ + DataFusionError, cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, exec_err, - DataFusionError, }; use datafusion_expr::{ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, @@ -266,7 +264,7 @@ pub fn compute_hex( return exec_err!( "hex got an unexpected argument type: {}", dict_values.data_type() - ) + ); } }; @@ -292,8 +290,8 @@ mod test { use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; use arrow::{ array::{ - as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, - StringDictionaryBuilder, + BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, + as_string_array, }, datatypes::{Int32Type, Int64Type}, }; From ba8d1ee17336a67525b4b06175c3d927a23c2f43 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Fri, 16 Jan 2026 23:19:06 +0800 Subject: [PATCH 3/4] refactor hex_encode_bytes/hex_encode_int64 && add datatype for dictionary --- datafusion/spark/src/function/math/hex.rs | 89 +++++++++++++++++------ 1 file changed, 67 insertions(+), 22 deletions(-) diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index 5d88d8684d49..06c77f37021b 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -138,7 +138,7 @@ fn hex_encode_bytes<'a, I, T>( iter: I, lowercase: bool, len: usize, -) -> Result +) -> Result where I: Iterator>, T: AsRef<[u8]> + 'a, @@ -168,14 +168,14 @@ where } } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + Ok(Arc::new(builder.finish())) } /// Generic hex encoding for int64 type -fn hex_encode_int64(iter: I, len: usize) -> Result -where - I: Iterator>, -{ +fn hex_encode_int64( + iter: impl Iterator>, + len: usize, +) -> Result { let mut builder = StringBuilder::with_capacity(len, len * 16); for v in iter { @@ -191,7 +191,7 @@ where } } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + Ok(Arc::new(builder.finish())) } /// Spark-compatible `hex` function @@ -217,37 +217,71 @@ pub fn compute_hex( ColumnarValue::Array(array) => match array.data_type() { DataType::Int64 => { let array = as_int64_array(array)?; - hex_encode_int64(array.iter(), array.len()) + Ok(ColumnarValue::Array(hex_encode_int64( + array.iter(), + array.len(), + )?)) } DataType::Utf8 => { let array = as_string_array(array); - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Utf8View => { let array = as_string_view_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeUtf8 => { let array = as_largestring_array(array); - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Binary => { let array = as_binary_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeBinary => { let array = as_large_binary_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } - DataType::Dictionary(_, _) => { + DataType::Dictionary(key_type, _) => { + if **key_type != DataType::Int32 { + return exec_err!( + "hex only supports Int32 dictionary keys, get: {}", + key_type + ); + } + let dict = as_dictionary_array::(&array); let dict_values = dict.values(); - let encoded_values: ColumnarValue = match dict_values.data_type() { + let encoded_values = match dict_values.data_type() { DataType::Int64 => { let arr = as_int64_array(dict_values)?; hex_encode_int64(arr.iter(), arr.len())? @@ -256,10 +290,26 @@ pub fn compute_hex( let arr = as_string_array(dict_values); hex_encode_bytes(arr.iter(), lowercase, arr.len())? } + DataType::LargeUtf8 => { + let arr = as_largestring_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Utf8View => { + let arr = as_string_view_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } DataType::Binary => { let arr = as_binary_array(dict_values)?; hex_encode_bytes(arr.iter(), lowercase, arr.len())? } + DataType::LargeBinary => { + let arr = as_large_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::FixedSizeBinary(_) => { + let arr = as_fixed_size_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } _ => { return exec_err!( "hex got an unexpected argument type: {}", @@ -268,12 +318,7 @@ pub fn compute_hex( } }; - let encoded_values_array: ArrayRef = match encoded_values { - ColumnarValue::Array(a) => a, - ColumnarValue::Scalar(s) => Arc::new(s.to_array()?), - }; - - let new_dict = dict.with_values(encoded_values_array); + let new_dict = dict.with_values(encoded_values); Ok(ColumnarValue::Array(Arc::new(new_dict))) } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), From dcde0f81b6fa942f11500df14d37353052d683c0 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Sat, 17 Jan 2026 12:42:13 +0800 Subject: [PATCH 4/4] update slt to dictionary(binary) --- datafusion/sqllogictest/test_files/spark/math/hex.slt | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt index 756088a26d6c..17e9ff432890 100644 --- a/datafusion/sqllogictest/test_files/spark/math/hex.slt +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -65,12 +65,12 @@ SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b; 74657374 statement ok -CREATE TABLE t_dict_utf8 AS -SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') as dict_col +CREATE TABLE t_dict_binary AS +SELECT arrow_cast(column1, 'Dictionary(Int32, Binary)') as dict_col FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar'); query T -SELECT hex(dict_col) FROM t_dict_utf8; +SELECT hex(dict_col) FROM t_dict_binary; ---- 666F6F 626172 @@ -78,3 +78,8 @@ SELECT hex(dict_col) FROM t_dict_utf8; NULL 62617A 626172 + +query T +SELECT arrow_typeof(hex(dict_col)) FROM t_dict_binary LIMIT 1; +---- +Dictionary(Int32, Utf8)