/// Adopted from strsim-rs for string similarity metrics
pub mod datafusion_strsim {
    // Source: https://github.com/dguo/strsim-rs/blob/master/src/lib.rs
    // License: https://github.com/dguo/strsim-rs/blob/master/LICENSE
    use std::cmp::min;
    use std::str::Chars;

    /// Thin wrapper around a `&str` so that a *reference* to it can be iterated
    /// by `char` more than once (each `into_iter` call yields a fresh `Chars`).
    struct StringWrapper<'a>(&'a str);

    impl<'a, 'b> IntoIterator for &'a StringWrapper<'b> {
        type Item = char;
        type IntoIter = Chars<'b>;

        fn into_iter(self) -> Self::IntoIter {
            self.0.chars()
        }
    }

    /// Calculates the minimum number of insertions, deletions, and substitutions
    /// required to change one sequence into the other.
    ///
    /// Both inputs are taken by reference and iterated several times, which is
    /// why the bounds are placed on `&Iter1` / `&Iter2` rather than on the
    /// owned types. Elements of `a` only need to be comparable with elements
    /// of `b` (`Elem1: PartialEq<Elem2>`).
    fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>(
        a: &'a Iter1,
        b: &'b Iter2,
    ) -> usize
    where
        &'a Iter1: IntoIterator<Item = Elem1>,
        &'b Iter2: IntoIterator<Item = Elem2>,
        Elem1: PartialEq<Elem2>,
    {
        let b_len = b.into_iter().count();

        // Turning an empty `a` into `b` takes exactly `b_len` insertions.
        if a.into_iter().next().is_none() {
            return b_len;
        }

        // Single-row dynamic programming table: before processing any element
        // of `a`, `cache[j]` is the distance from the empty prefix of `a` to
        // the first `j + 1` elements of `b`.
        let mut cache: Vec<usize> = (1..=b_len).collect();

        let mut result = 0;

        for (i, a_elem) in a.into_iter().enumerate() {
            // Distance from the first `i + 1` elements of `a` to the empty
            // prefix of `b`; this is also the final answer when `b` is empty.
            result = i + 1;
            // Value diagonally up-left of the cell being computed.
            let mut distance_b = i;

            for (j, b_elem) in b.into_iter().enumerate() {
                let cost = if a_elem == b_elem { 0usize } else { 1usize };
                // Substitution (or match) cost via the diagonal cell.
                let distance_a = distance_b + cost;
                // Previous row's value for this column, read before it is
                // overwritten below.
                distance_b = cache[j];
                result = min(result + 1, min(distance_a, distance_b + 1));
                cache[j] = result;
            }
        }

        result
    }

    /// Calculates the minimum number of insertions, deletions, and substitutions
    /// required to change one string into the other.
    ///
    /// The comparison is performed per `char`, not per byte.
    ///
    /// ```
    /// use datafusion_common::utils::datafusion_strsim::levenshtein;
    ///
    /// assert_eq!(3, levenshtein("kitten", "sitting"));
    /// ```
    pub fn levenshtein(a: &str, b: &str) -> usize {
        generic_levenshtein(&StringWrapper(a), &StringWrapper(b))
    }
}
"SELECT \"max\"(i) FROM t") .await diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index 6083f8834ca30..cb9827cb7475e 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -64,10 +64,9 @@ async fn case_sensitive_identifiers_functions() { let err = plan_and_collect(&ctx, "SELECT \"SQRT\"(i) FROM t") .await .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Invalid function 'SQRT'" - ); + assert!(err + .to_string() + .contains("Error during planning: Invalid function 'SQRT'")); let results = plan_and_collect(&ctx, "SELECT \"sqrt\"(i) FROM t") .await diff --git a/datafusion/core/tests/sqllogictests/test_files/functions.slt b/datafusion/core/tests/sqllogictests/test_files/functions.slt index c6f2db8eb2332..0cb54711cfbd7 100644 --- a/datafusion/core/tests/sqllogictests/test_files/functions.slt +++ b/datafusion/core/tests/sqllogictests/test_files/functions.slt @@ -418,3 +418,43 @@ SELECT length(c1) FROM test statement ok drop table test + +# +# Testing error message for wrong function name +# + +statement ok +CREATE TABLE test( + v1 Int, + v2 Int +) as VALUES +(1, 10), +(2, 20), +(3, 30); + +# Scalar function +statement error Did you mean 'arrow_typeof'? +SELECT arrowtypeof(v1) from test; + +# Scalar function +statement error Did you mean 'to_timestamp_seconds'? +SELECT to_TIMESTAMPS_second(v2) from test; + +# Aggregate function +statement error Did you mean 'COUNT'? +SELECT counter(*) from test; + +# Aggregate function +statement error Did you mean 'STDDEV'? +SELECT STDEV(v1) from test; + +# Window function +statement error Did you mean 'SUM'? +SELECT v1, v2, SUMM(v2) OVER(ORDER BY v1) from test; + +# Window function +statement error Did you mean 'ROW_NUMBER'? 
+SELECT v1, v2, ROWNUMBER() OVER(ORDER BY v1) from test; + +statement ok +drop table test diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 8258c8b80585a..5b0676a815099 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -22,9 +22,10 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::{DataFusionError, Result}; use std::sync::Arc; use std::{fmt, str::FromStr}; +use strum_macros::EnumIter; /// Enum of all built-in aggregate functions -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { /// count Count, diff --git a/datafusion/expr/src/function_err.rs b/datafusion/expr/src/function_err.rs index 39ac4ef8039a7..e97e0f92cd80f 100644 --- a/datafusion/expr/src/function_err.rs +++ b/datafusion/expr/src/function_err.rs @@ -29,8 +29,12 @@ //! ``` use crate::function::signature; -use crate::{BuiltinScalarFunction, TypeSignature}; +use crate::{ + AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, TypeSignature, +}; use arrow::datatypes::DataType; +use datafusion_common::utils::datafusion_strsim; +use strum::IntoEnumIterator; impl TypeSignature { fn to_string_repr(&self) -> Vec { @@ -89,3 +93,33 @@ pub fn generate_signature_error_msg( fun, join_types(input_expr_types, ", "), candidate_signatures ) } + +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) +/// Input `candidates` must not be empty otherwise it will panic +fn find_closest_match(candidates: Vec, target: &str) -> String { + let target = target.to_lowercase(); + candidates + .into_iter() + .min_by_key(|candidate| { + datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) + }) + .expect("No candidates provided.") // Panic if `candidates` argument is empty +} + +/// Suggest a valid function based on an 
invalid input function name +pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { + let valid_funcs = if is_window_func { + // All aggregate functions and builtin window functions + AggregateFunction::iter() + .map(|func| func.to_string()) + .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) + .collect() + } else { + // All scalar functions and aggregate functions + BuiltinScalarFunction::iter() + .map(|func| func.to_string()) + .chain(AggregateFunction::iter().map(|func| func.to_string())) + .collect() + }; + find_closest_match(valid_funcs, input_function_name) +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 45777e09d4bb9..59781e16566b9 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -37,7 +37,7 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; -mod function_err; +pub mod function_err; mod literal; pub mod logical_plan; mod nullif; diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 1bae3a162e509..ac8a731a1611c 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -28,6 +28,7 @@ use arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result}; use std::sync::Arc; use std::{fmt, str::FromStr}; +use strum_macros::EnumIter; /// WindowFunction #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -73,7 +74,7 @@ impl fmt::Display for WindowFunction { } /// An aggregate function that is part of a built-in window function -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] pub enum BuiltInWindowFunction { /// number of the current row within its partition, counting from 1 RowNumber, diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 70489203b2000..4f45ed60c8fe1 100644 --- a/datafusion/sql/src/expr/function.rs +++ 
b/datafusion/sql/src/expr/function.rs @@ -18,6 +18,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; +use datafusion_expr::function_err::suggest_valid_function; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::window_frame::regularize; use datafusion_expr::{ @@ -56,7 +57,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument // required ordering should be defined in OVER clause. - if !function.order_by.is_empty() && function.over.is_some() { + let is_function_window = function.over.is_some(); + if !function.order_by.is_empty() && is_function_window { return Err(DataFusionError::Plan( "Aggregate ORDER BY is not implemented for window functions".to_string(), )); @@ -84,73 +86,88 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { WindowFrame::new(!order_by.is_empty()) }; - let fun = self.find_window_func(&name)?; - let expr = match fun { - WindowFunction::AggregateFunction(aggregate_fun) => { - let (aggregate_fun, args) = self.aggregate_fn_to_expr( - aggregate_fun, - function.args, - schema, - planner_context, - )?; - - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), - args, + if let Ok(fun) = self.find_window_func(&name) { + let expr = match fun { + WindowFunction::AggregateFunction(aggregate_fun) => { + let (aggregate_fun, args) = self.aggregate_fn_to_expr( + aggregate_fun, + function.args, + schema, + planner_context, + )?; + + Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::AggregateFunction(aggregate_fun), + args, + partition_by, + order_by, + window_frame, + )) + } + _ => Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr( + function.args, + schema, + planner_context, + )?, 
partition_by, order_by, window_frame, - )) - } - _ => Expr::WindowFunction(expr::WindowFunction::new( + )), + }; + return Ok(expr); + } + } else { + // next, aggregate built-ins + if let Ok(fun) = AggregateFunction::from_str(&name) { + let distinct = function.distinct; + let order_by = self.order_by_to_sort_expr( + &function.order_by, + schema, + planner_context, + )?; + let order_by = (!order_by.is_empty()).then_some(order_by); + let (fun, args) = self.aggregate_fn_to_expr( fun, - self.function_args_to_expr(function.args, schema, planner_context)?, - partition_by, - order_by, - window_frame, - )), + function.args, + schema, + planner_context, + )?; + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( + fun, args, distinct, None, order_by, + ))); }; - return Ok(expr); - } - // next, aggregate built-ins - if let Ok(fun) = AggregateFunction::from_str(&name) { - let distinct = function.distinct; - let order_by = - self.order_by_to_sort_expr(&function.order_by, schema, planner_context)?; - let order_by = (!order_by.is_empty()).then_some(order_by); - let (fun, args) = - self.aggregate_fn_to_expr(fun, function.args, schema, planner_context)?; - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, order_by, - ))); - }; - - // finally, user-defined functions (UDF) and UDAF - if let Some(fm) = self.schema_provider.get_function_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); - } + // finally, user-defined functions (UDF) and UDAF + if let Some(fm) = self.schema_provider.get_function_meta(&name) { + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; + return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); + } - // User defined aggregate functions - if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, 
planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, - ))); - } + // User defined aggregate functions + if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; + return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + fm, args, None, None, + ))); + } - // Special case arrow_cast (as its type is dependent on its argument value) - if name == ARROW_CAST_NAME { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return super::arrow_cast::create_arrow_cast(args, schema); + // Special case arrow_cast (as its type is dependent on its argument value) + if name == ARROW_CAST_NAME { + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; + return super::arrow_cast::create_arrow_cast(args, schema); + } } // Could not find the relevant function, so return an error - Err(DataFusionError::Plan(format!("Invalid function '{name}'"))) + let suggested_func_name = suggest_valid_function(&name, is_function_window); + Err(DataFusionError::Plan(format!( + "Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?" + ))) } pub(super) fn sql_named_function_to_expr(