diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 37180bff5a775..a89cfd8d156bc 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -22,28 +22,18 @@ //! Unicode expressions use arrow::{ - array::{ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, PrimitiveArray}, + array::{ArrayRef, GenericStringArray, OffsetSizeTrait, PrimitiveArray}, datatypes::{ArrowNativeType, ArrowPrimitiveType}, }; -use datafusion_common::{cast::as_generic_string_array, DataFusionError, Result}; +use datafusion_common::{ + cast::{as_generic_string_array, as_int64_array}, + DataFusionError, Result, +}; use hashbrown::HashMap; -use std::cmp::Ordering; +use std::cmp::{max, Ordering}; use std::sync::Arc; -use std::{any::type_name, cmp::max}; use unicode_segmentation::UnicodeSegmentation; -macro_rules! downcast_arg { - ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast {} to {}", - $NAME, - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - /// Returns number of characters in the string. /// character_length('josé') = 4 /// The implementation counts UTF-8 code points to count the number of characters @@ -72,7 +62,7 @@ where /// The implementation uses UTF-8 code points as characters pub fn left(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; - let n_array = downcast_arg!(args[1], "n", Int64Array); + let n_array = as_int64_array(&args[1])?; let result = string_array .iter() .zip(n_array.iter()) @@ -104,7 +94,7 @@ pub fn lpad(args: &[ArrayRef]) -> Result { match args.len() { 2 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -140,7 +130,7 @@ pub fn lpad(args: &[ArrayRef]) -> Result { } 3 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let fill_array = as_generic_string_array::(&args[2])?; let result = string_array @@ -216,7 +206,7 @@ pub fn reverse(args: &[ArrayRef]) -> Result { /// The implementation uses UTF-8 code points as characters pub fn right(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; - let n_array = downcast_arg!(args[1], "n", Int64Array); + let n_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -250,7 +240,7 @@ pub fn rpad(args: &[ArrayRef]) -> Result { match args.len() { 2 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -285,7 +275,7 @@ pub fn rpad(args: &[ArrayRef]) -> Result { } 3 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let fill_array = as_generic_string_array::(&args[2])?; let result = string_array @@ -376,7 +366,7 @@ pub fn substr(args: &[ArrayRef]) -> Result { match args.len() { 2 => { let string_array = as_generic_string_array::(&args[0])?; - let start_array = downcast_arg!(args[1], "start", Int64Array); + let start_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -397,8 +387,8 @@ pub fn substr(args: &[ArrayRef]) -> Result { } 3 => { let string_array = as_generic_string_array::(&args[0])?; - let start_array = downcast_arg!(args[1], "start", Int64Array); - let count_array = downcast_arg!(args[2], "count", Int64Array); + let start_array = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[2])?; let result = string_array .iter()