From 6298c0c71f013c6d1dc11ec26a4973c8ba738687 Mon Sep 17 00:00:00 2001 From: retikulum Date: Mon, 12 Dec 2022 23:21:15 +0300 Subject: [PATCH 1/2] Improve error handling for array downcasting --- .../physical-expr/src/unicode_expressions.rs | 36 +++++++------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 37180bff5a775..19ccc80cc9cf6 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -22,28 +22,16 @@ //! Unicode expressions use arrow::{ - array::{ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, PrimitiveArray}, + array::{ArrayRef, GenericStringArray, OffsetSizeTrait, PrimitiveArray}, datatypes::{ArrowNativeType, ArrowPrimitiveType}, }; -use datafusion_common::{cast::as_generic_string_array, DataFusionError, Result}; +use datafusion_common::{cast::{as_int64_array, as_generic_string_array}, DataFusionError, Result}; use hashbrown::HashMap; use std::cmp::Ordering; use std::sync::Arc; -use std::{any::type_name, cmp::max}; +use std::{cmp::max}; use unicode_segmentation::UnicodeSegmentation; -macro_rules! downcast_arg { - ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast {} to {}", - $NAME, - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - /// Returns number of characters in the string. /// character_length('josé') = 4 /// The implementation counts UTF-8 code points to count the number of characters @@ -72,7 +60,7 @@ where /// The implementation uses UTF-8 code points as characters pub fn left(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; - let n_array = downcast_arg!(args[1], "n", Int64Array); + let n_array = as_int64_array(&args[1])?; let result = string_array .iter() .zip(n_array.iter()) @@ -104,7 +92,7 @@ pub fn lpad(args: &[ArrayRef]) -> Result { match args.len() { 2 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -140,7 +128,7 @@ pub fn lpad(args: &[ArrayRef]) -> Result { } 3 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let fill_array = as_generic_string_array::(&args[2])?; let result = string_array @@ -216,7 +204,7 @@ pub fn reverse(args: &[ArrayRef]) -> Result { /// The implementation uses UTF-8 code points as characters pub fn right(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; - let n_array = downcast_arg!(args[1], "n", Int64Array); + let n_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -250,7 +238,7 @@ pub fn rpad(args: &[ArrayRef]) -> Result { match args.len() { 2 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -285,7 +273,7 @@ pub fn rpad(args: &[ArrayRef]) -> Result { } 3 => { let string_array = as_generic_string_array::(&args[0])?; - let length_array = downcast_arg!(args[1], "length", Int64Array); + let length_array = as_int64_array(&args[1])?; let fill_array = as_generic_string_array::(&args[2])?; let result = string_array @@ -376,7 +364,7 @@ pub fn substr(args: &[ArrayRef]) -> Result { match args.len() { 2 => { let string_array = as_generic_string_array::(&args[0])?; - let start_array = downcast_arg!(args[1], "start", Int64Array); + let start_array = as_int64_array(&args[1])?; let result = string_array .iter() @@ -397,8 +385,8 @@ pub fn substr(args: &[ArrayRef]) -> Result { } 3 => { let string_array = as_generic_string_array::(&args[0])?; - let start_array = downcast_arg!(args[1], "start", Int64Array); - let count_array = downcast_arg!(args[2], "count", Int64Array); + let start_array = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[2])?; let result = string_array .iter() From 668aa3dc3ad6e2535ddfd74b4adccedc13ace7c3 Mon Sep 17 00:00:00 2001 From: retikulum Date: Mon, 12 Dec 2022 23:30:48 +0300 Subject: [PATCH 2/2] fix formatting --- datafusion/physical-expr/src/unicode_expressions.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 19ccc80cc9cf6..a89cfd8d156bc 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -25,11 +25,13 @@ use arrow::{ array::{ArrayRef, GenericStringArray, OffsetSizeTrait, PrimitiveArray}, datatypes::{ArrowNativeType, ArrowPrimitiveType}, }; -use datafusion_common::{cast::{as_int64_array, as_generic_string_array}, DataFusionError, Result}; +use datafusion_common::{ + cast::{as_generic_string_array, as_int64_array}, + DataFusionError, Result, +}; use hashbrown::HashMap; -use std::cmp::Ordering; +use std::cmp::{max, Ordering}; use std::sync::Arc; -use std::{cmp::max}; use unicode_segmentation::UnicodeSegmentation; /// Returns number of characters in the string.