From e8c7a74bc6ee6a37b25cfb9f87e471a8e79df1f3 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Fri, 22 Dec 2023 09:47:39 +0100 Subject: [PATCH] add arguments length check --- .../physical-expr/src/array_expressions.rs | 110 +++++++++++++++++- 1 file changed, 107 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index bdab65cab9e33..fe67b0b79cce1 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -444,6 +444,10 @@ where /// For example: /// > array_element(\[1, 2, 3], 2) -> 2 pub fn array_element(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_element needs two arguments"); + } + match &args[0].data_type() { DataType::List(_) => { let array = as_list_array(&args[0])?; @@ -557,6 +561,10 @@ pub fn array_except(args: &[ArrayRef]) -> Result { /// /// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_slice needs three arguments"); + } + let array_data_type = args[0].data_type(); match array_data_type { DataType::List(_) => { @@ -708,6 +716,10 @@ where /// array_pop_back SQL function pub fn array_pop_back(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_pop_back needs one argument"); + } + let list_array = as_list_array(&args[0])?; let from_array = Int64Array::from(vec![1; list_array.len()]); let to_array = Int64Array::from( @@ -857,6 +869,10 @@ pub fn array_pop_front(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_append expects two arguments"); + } + let list_array = as_list_array(&args[0])?; let element_array = &args[1]; @@ -883,6 +899,10 @@ pub fn array_append(args: &[ArrayRef]) -> Result { /// Array_sort SQL function pub fn array_sort(args: &[ArrayRef]) -> Result { + if args.is_empty() || args.len() > 3 { + return exec_err!("array_sort expects one to three arguments"); + } + let sort_option = match args.len() { 1 => None, 2 => { @@ -962,6 +982,10 @@ fn order_nulls_first(modifier: &str) -> Result { /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_prepend expects two arguments"); + } + let list_array = as_list_array(&args[1])?; let element_array = &args[0]; @@ -1082,6 +1106,10 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_concat/Array_cat SQL function pub fn array_concat(args: &[ArrayRef]) -> Result { + if args.is_empty() { + return exec_err!("array_concat expects at least one arguments"); + } + let mut new_args = vec![]; for arg in args { let ndim = list_ndims(arg.data_type()); @@ -1098,6 +1126,10 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_empty expects one argument"); + } + if as_null_array(&args[0]).is_ok() { // Make sure to return Boolean type. return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); @@ -1122,6 +1154,10 @@ fn array_empty_dispatch(array: &ArrayRef) -> Result Result { + if args.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + let element = &args[0]; let count_array = as_int64_array(&args[1])?; @@ -1257,6 +1293,10 @@ fn general_list_repeat( /// Array_position SQL function pub fn array_position(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_position expects two or three arguments"); + } + let list_array = as_list_array(&args[0])?; let element_array = &args[1]; @@ -1321,6 +1361,10 @@ fn general_position( /// Array_positions SQL function pub fn array_positions(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_positions expects two arguments"); + } + let element = &args[1]; match &args[0].data_type() { @@ -1480,16 +1524,28 @@ fn array_remove_internal( } pub fn array_remove_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove_all expects two arguments"); + } + let arr_n = vec![i64::MAX; args[0].len()]; array_remove_internal(&args[0], &args[1], arr_n) } pub fn array_remove(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove expects two arguments"); + } + let arr_n = vec![1; args[0].len()]; array_remove_internal(&args[0], &args[1], arr_n) } pub fn array_remove_n(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_remove_n expects three arguments"); + } + let arr_n = as_int64_array(&args[2])?.values().to_vec(); array_remove_internal(&args[0], &args[1], arr_n) } @@ -1593,18 +1649,30 @@ fn general_replace( } pub fn array_replace(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace expects three arguments"); + } + // replace at most one occurence for each element let arr_n = vec![1; args[0].len()]; general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } pub fn array_replace_n(args: &[ArrayRef]) -> Result { + if args.len() != 4 { + return exec_err!("array_replace_n expects four arguments"); + } + // replace the specified number of occurences let arr_n = as_int64_array(&args[3])?.values().to_vec(); general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } pub fn array_replace_all(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace_all expects three arguments"); + } + // replace all occurrences (up to "i64::MAX") let arr_n = vec![i64::MAX; args[0].len()]; general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) @@ -1682,7 +1750,7 @@ fn union_generic_lists( /// Array_union SQL function pub fn array_union(args: &[ArrayRef]) -> Result { if args.len() != 2 { - return exec_err!("array_union needs two arguments"); + return exec_err!("array_union needs 2 arguments"); } let array1 = &args[0]; let array2 = &args[1]; @@ -1724,6 +1792,10 @@ pub fn array_union(args: &[ArrayRef]) -> Result { /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + let arr = &args[0]; let delimiters = as_string_array(&args[1])?; @@ -1833,6 +1905,10 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { /// Cardinality SQL function pub fn cardinality(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } + let list_array = as_list_array(&args[0])?.clone(); let result = list_array @@ -1889,6 +1965,10 @@ fn flatten_internal( /// Flatten SQL function pub fn flatten(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("flatten expects one argument"); + } + let flattened_array = flatten_internal(&args[0], None)?; Ok(Arc::new(flattened_array) as ArrayRef) } @@ -1913,6 +1993,10 @@ fn array_length_dispatch(array: &[ArrayRef]) -> Result Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!("array_length expects one or two arguments"); + } + match &args[0].data_type() { DataType::List(_) => array_length_dispatch::(args), DataType::LargeList(_) => array_length_dispatch::(args), @@ -1959,6 +2043,10 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); + } + if let Some(list_array) = args[0].as_list_opt::() { let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); @@ -2049,6 +2137,10 @@ fn general_array_has_dispatch( /// Array_has SQL function pub fn array_has(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2064,6 +2156,10 @@ pub fn array_has(args: &[ArrayRef]) -> Result { /// Array_has_any SQL function pub fn array_has_any(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_any needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2079,6 +2175,10 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_all needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2183,7 +2283,9 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { - assert_eq!(args.len(), 2); + if args.len() != 2 { + return exec_err!("array_intersect needs two arguments"); + } let first_array = &args[0]; let second_array = &args[1]; @@ -2286,7 +2388,9 @@ pub fn general_array_distinct( /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] pub fn array_distinct(args: &[ArrayRef]) -> Result { - assert_eq!(args.len(), 1); + if args.len() != 1 { + return exec_err!("array_distinct needs one argument"); + } // handle null if args[0].data_type() == &DataType::Null {