Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,74 @@ impl<T: ?Sized> DataPtr for Arc<T> {
}
}

/// Adopted from strsim-rs for string similarity metrics
pub mod datafusion_strsim {
// Source: https://github.com/dguo/strsim-rs/blob/master/src/lib.rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 thank you for the link

// License: https://github.com/dguo/strsim-rs/blob/master/LICENSE
use std::cmp::min;
use std::str::Chars;

/// Borrowed-string newtype that lets a `&str` be passed to code which is
/// generic over `IntoIterator`, iterating over the string `char` by `char`.
struct StringWrapper<'a>(&'a str);

impl<'b> IntoIterator for &StringWrapper<'b> {
    type IntoIter = Chars<'b>;
    type Item = char;

    /// Returns an iterator over the `char`s of the wrapped string slice.
    ///
    /// The iterator borrows from the underlying `str`, not from the wrapper.
    fn into_iter(self) -> Self::IntoIter {
        let StringWrapper(text) = self;
        text.chars()
    }
}

/// Computes the Levenshtein edit distance between two sequences: the minimum
/// number of single-element insertions, deletions, and substitutions needed
/// to turn `a` into `b`.
///
/// Both arguments are taken by reference and iterated several times, so the
/// reference types must implement `IntoIterator`. Elements of `a` must be
/// comparable against elements of `b`.
///
/// Only a single row of the dynamic-programming table is kept in memory, so
/// the extra space used is proportional to the length of `b`.
fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>(
    a: &'a Iter1,
    b: &'b Iter2,
) -> usize
where
    &'a Iter1: IntoIterator<Item = Elem1>,
    &'b Iter2: IntoIterator<Item = Elem2>,
    Elem1: PartialEq<Elem2>,
{
    let b_len = b.into_iter().count();

    // Turning an empty sequence into `b` takes exactly `b_len` insertions.
    if a.into_iter().next().is_none() {
        return b_len;
    }

    // `row[j]` is the distance between the prefix of `a` processed so far and
    // the first `j + 1` elements of `b`. Before any element of `a` has been
    // consumed that is simply `j + 1` insertions.
    let mut row: Vec<usize> = (1..=b_len).collect();
    let mut latest = 0usize;

    for (i, a_elem) in a.into_iter().enumerate() {
        // Cell to the left of the first column: delete the first `i + 1`
        // elements of `a`.
        latest = i + 1;
        // Cell diagonally up-left of the one about to be computed.
        let mut diagonal = i;

        for (cell, b_elem) in row.iter_mut().zip(b) {
            let above = *cell;
            let substitution = if a_elem == b_elem {
                diagonal
            } else {
                diagonal + 1
            };
            // Cheapest of: coming from the left, substituting, coming from above.
            latest = (latest + 1).min(substitution).min(above + 1);
            diagonal = above;
            *cell = latest;
        }
    }

    latest
}

/// Calculates the minimum number of insertions, deletions, and substitutions
/// required to change one string into the other.
///
/// ```
/// use datafusion_common::utils::datafusion_strsim::levenshtein;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

///
/// assert_eq!(3, levenshtein("kitten", "sitting"));
/// ```
pub fn levenshtein(a: &str, b: &str) -> usize {
    // Wrap both strings so the generic routine can iterate them by `char`.
    let (source, target) = (StringWrapper(a), StringWrapper(b));
    generic_levenshtein(&source, &target)
}
}

#[cfg(test)]
mod tests {
use arrow::array::Float64Array;
Expand Down
14 changes: 6 additions & 8 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2257,10 +2257,9 @@ mod tests {
let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Error during planning: Invalid function \'my_func\'"
);
assert!(err
.to_string()
.contains("Error during planning: Invalid function \'my_func\'"));

// Can call it if you put quotes
let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
Expand Down Expand Up @@ -2304,10 +2303,9 @@ mod tests {
let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Error during planning: Invalid function \'my_avg\'"
);
assert!(err
.to_string()
.contains("Error during planning: Invalid function \'my_avg\'"));

// Can call it if you put quotes
let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;
Expand Down
7 changes: 3 additions & 4 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,9 @@ async fn case_sensitive_identifiers_aggregates() {
let err = plan_and_collect(&ctx, "SELECT \"MAX\"(i) FROM t")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Error during planning: Invalid function 'MAX'"
);
assert!(err
.to_string()
.contains("Error during planning: Invalid function 'MAX'"));

let results = plan_and_collect(&ctx, "SELECT \"max\"(i) FROM t")
.await
Expand Down
7 changes: 3 additions & 4 deletions datafusion/core/tests/sql/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,9 @@ async fn case_sensitive_identifiers_functions() {
let err = plan_and_collect(&ctx, "SELECT \"SQRT\"(i) FROM t")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Error during planning: Invalid function 'SQRT'"
);
assert!(err
.to_string()
.contains("Error during planning: Invalid function 'SQRT'"));

let results = plan_and_collect(&ctx, "SELECT \"sqrt\"(i) FROM t")
.await
Expand Down
40 changes: 40 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,43 @@ SELECT length(c1) FROM test

statement ok
drop table test

#
# Testing error message for wrong function name
#

statement ok
CREATE TABLE test(
v1 Int,
v2 Int
) as VALUES
(1, 10),
(2, 20),
(3, 30);

# Scalar function
statement error Did you mean 'arrow_typeof'?
SELECT arrowtypeof(v1) from test;

# Scalar function
statement error Did you mean 'to_timestamp_seconds'?
SELECT to_TIMESTAMPS_second(v2) from test;

# Aggregate function
statement error Did you mean 'COUNT'?
SELECT counter(*) from test;

# Aggregate function
statement error Did you mean 'STDDEV'?
SELECT STDEV(v1) from test;

# Window function
statement error Did you mean 'SUM'?
SELECT v1, v2, SUMM(v2) OVER(ORDER BY v1) from test;

# Window function
statement error Did you mean 'ROW_NUMBER'?
SELECT v1, v2, ROWNUMBER() OVER(ORDER BY v1) from test;

statement ok
drop table test
3 changes: 2 additions & 1 deletion datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ use arrow::datatypes::{DataType, Field};
use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;
use std::{fmt, str::FromStr};
use strum_macros::EnumIter;

/// Enum of all built-in aggregate functions
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// count
Count,
Expand Down
36 changes: 35 additions & 1 deletion datafusion/expr/src/function_err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
//! ```

use crate::function::signature;
use crate::{BuiltinScalarFunction, TypeSignature};
use crate::{
AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, TypeSignature,
};
use arrow::datatypes::DataType;
use datafusion_common::utils::datafusion_strsim;
use strum::IntoEnumIterator;

impl TypeSignature {
fn to_string_repr(&self) -> Vec<String> {
Expand Down Expand Up @@ -89,3 +93,33 @@ pub fn generate_signature_error_msg(
fun, join_types(input_expr_types, ", "), candidate_signatures
)
}

/// Returns the entry of `candidates` with the smallest edit distance to
/// `target`. The comparison is case insensitive: both sides are lowercased
/// before the distance is computed. When several candidates are equally
/// close, the one that appears first in `candidates` is returned.
///
/// # Panics
///
/// Panics if `candidates` is empty.
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
    let needle = target.to_lowercase();
    let distance_to = |name: &str| -> usize {
        datafusion_strsim::levenshtein(&name.to_lowercase(), &needle)
    };

    candidates
        .into_iter()
        .map(|name| (distance_to(name.as_str()), name))
        .min_by_key(|(distance, _)| *distance)
        .map(|(_, name)| name)
        .expect("No candidates provided.") // Panic if `candidates` argument is empty
}

/// Suggests the valid built-in function name closest to `input_function_name`.
///
/// The candidate set depends on the call site:
/// * `is_window_func == true`: aggregate functions followed by built-in
///   window functions;
/// * `is_window_func == false`: built-in scalar functions followed by
///   aggregate functions.
///
/// The candidate with the smallest edit distance to the input is returned,
/// as chosen by `find_closest_match`.
pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String {
    // Aggregate names are candidates in both cases; only their position in the
    // candidate list differs.
    let aggregate_names = AggregateFunction::iter().map(|func| func.to_string());

    let valid_funcs: Vec<String> = if is_window_func {
        let window_names = BuiltInWindowFunction::iter().map(|func| func.to_string());
        aggregate_names.chain(window_names).collect()
    } else {
        let scalar_names = BuiltinScalarFunction::iter().map(|func| func.to_string());
        scalar_names.chain(aggregate_names).collect()
    };

    find_closest_match(valid_funcs, input_function_name)
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub mod expr_rewriter;
pub mod expr_schema;
pub mod field_util;
pub mod function;
mod function_err;
pub mod function_err;
mod literal;
pub mod logical_plan;
mod nullif;
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;
use std::{fmt, str::FromStr};
use strum_macros::EnumIter;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


/// WindowFunction
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -73,7 +74,7 @@ impl fmt::Display for WindowFunction {
}

/// An aggregate function that is part of a built-in window function
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)]
pub enum BuiltInWindowFunction {
/// number of the current row within its partition, counting from 1
RowNumber,
Expand Down
Loading