From df59c96c3470cfde7ec8a63ac1c557d8c469fdd9 Mon Sep 17 00:00:00 2001 From: cli2 Date: Fri, 10 Nov 2023 16:17:21 +0800 Subject: [PATCH 01/11] feat:implement posgres style 'overlay' string function --- datafusion/expr/src/built_in_function.rs | 16 ++++- datafusion/expr/src/expr_fn.rs | 6 ++ datafusion/physical-expr/src/functions.rs | 11 ++++ .../physical-expr/src/string_expressions.rs | 61 +++++++++++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 1 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 18 ++++-- datafusion/proto/src/logical_plan/to_proto.rs | 1 + 9 files changed, 113 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f3f52e9dafb6b..70a5d75c6861e 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -288,6 +288,8 @@ pub enum BuiltinScalarFunction { RegexpMatch, /// arrow_typeof ArrowTypeof, + /// overlay + OverLay, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -449,6 +451,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::OverLay => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -804,6 +807,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::OverLay => { + utf8_to_str_type(&input_expr_types[0], "overlay") + } + BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1248,7 +1255,13 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()), - + BuiltinScalarFunction::OverLay => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + ], + self.volatility(), + ), BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1505,6 +1518,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { ], BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + BuiltinScalarFunction::OverLay => &["overlay"], // struct functions BuiltinScalarFunction::Struct => &["struct"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5a60c2470c95b..4ae599278f236 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -817,6 +817,11 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); +nary_scalar_expr!( + OverLay, + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring" +); // date functions scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); @@ -1146,6 +1151,7 @@ mod test { test_nary_scalar_expr!(MakeArray, array, input); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); + test_nary_scalar_expr!(OverLay, overlay, string, characters, from, len); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index f14bad093ac74..89b7c0c3d339e 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -927,6 +927,17 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), + BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function overlay", + ))), + }), _ => { return internal_err!( "create_physical_fun: Unsupported scalar function {fun:?}" diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index e6a3d5c331a54..6eb99f76857d3 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -553,11 +553,55 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// datafusion overlay('Txxxxas', 'hom', 2, 4) -> Thomas +pub fn overlay(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow_array::Int64Array; use super::*; @@ -599,4 +643,21 @@ mod tests { Ok(()) } + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bc6de2348e8d5..8ebdf82d68252 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -621,6 +621,7 @@ enum ScalarFunction { ArrayPopBack = 116; StringToArray = 117; ToTimestampNanos = 118; + OverLay = 119; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 659a25f9fa356..78cdfb1445b1a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -21019,6 +21019,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), + "OverLay" => OK(ScalaFunction::OverLay), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 75050e9d3dfad..75edc9bbefd0b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2539,6 +2539,7 @@ pub enum ScalarFunction { ArrayPopBack = 116, StringToArray = 117, ToTimestampNanos = 118, + OverLay = 119, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2666,6 +2667,7 @@ impl ScalarFunction { ScalarFunction::ArrayPopBack => "ArrayPopBack", ScalarFunction::StringToArray => "StringToArray", ScalarFunction::ToTimestampNanos => "ToTimestampNanos", + ScalarFunction::OverLay => "OverLay", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2790,6 +2792,7 @@ impl ScalarFunction { "ArrayPopBack" => Some(Self::ArrayPopBack), "StringToArray" => Some(Self::StringToArray), "ToTimestampNanos" => Some(Self::ToTimestampNanos), + "OverLay" => Some(Self::OverLay), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index cdb0fe9bda7fb..2b1010ac4950e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -50,10 +50,10 @@ use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, factorial, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, - random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, + radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, @@ -540,6 +540,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Isnan => Self::Isnan, ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::OverLay => Self::OverLay, } } } @@ -1648,6 +1649,15 @@ pub fn parse_expr( _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), + ScalarFunction::OverLay => Ok(overlay( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), + _ => Err(proto_error( + "Protobuf deserialization error: Unsupported scalar function", + )), } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 687b73cfc886f..9209d97f8bb7b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1539,6 +1539,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Isnan => Self::Isnan, BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::OverLay => Self::OverLay, }; Ok(scalar_function) From 35416d2775d9044c3564df91ab5898196af2e0ee Mon Sep 17 00:00:00 2001 From: cli2 Date: Fri, 10 Nov 2023 16:42:02 +0800 Subject: [PATCH 02/11] code format --- datafusion/proto/src/generated/pbjson.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 78cdfb1445b1a..e3cd129a9f3a0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -21019,7 +21019,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), - "OverLay" => OK(ScalaFunction::OverLay), + "OverLay" => OK(ScalarFunction::OverLay), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } From 73123d4a5a7b340151b14f0cd829d14ed3ef5b3e Mon Sep 17 00:00:00 2001 From: cli2 Date: Fri, 10 Nov 2023 16:49:59 +0800 Subject: [PATCH 03/11] code format --- datafusion/proto/src/generated/pbjson.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e3cd129a9f3a0..6a5e1602a173a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20730,6 +20730,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayPopBack => "ArrayPopBack", Self::StringToArray => "StringToArray", Self::ToTimestampNanos => "ToTimestampNanos", + Self::OverLay => "OverLay", }; serializer.serialize_str(variant) } @@ -20860,6 +20861,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack", "StringToArray", "ToTimestampNanos", + "OverLay", ]; struct GeneratedVisitor; @@ -21019,7 +21021,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), - "OverLay" => OK(ScalarFunction::OverLay), + "OverLay" => Ok(ScalarFunction::OverLay), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } From ab98011bc0960c2c062e5a82bd1a5fb3c65af3b4 Mon Sep 17 00:00:00 2001 From: cli2 Date: Fri, 10 Nov 2023 17:02:18 +0800 Subject: [PATCH 04/11] code format --- datafusion/proto/src/logical_plan/from_proto.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2b1010ac4950e..acb88c482af9e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -55,7 +55,7 @@ use datafusion_expr::{ round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, - trunc, upper, uuid, + trunc, upper, uuid window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, @@ -1646,9 +1646,6 @@ pub fn parse_expr( )), ScalarFunction::Isnan => Ok(isnan(parse_expr(&args[0], registry)?)), ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), - _ => Err(proto_error( - "Protobuf deserialization error: Unsupported scalar function", - )), ScalarFunction::OverLay => Ok(overlay( args.to_owned() .iter() From d1ebff21440f5e1d04355f94381ecca39ca03e3a Mon Sep 17 00:00:00 2001 From: cli2 Date: Fri, 10 Nov 2023 17:09:27 +0800 Subject: [PATCH 05/11] code format --- datafusion/proto/src/logical_plan/from_proto.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index acb88c482af9e..64c9236d6047f 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -55,7 +55,7 @@ use datafusion_expr::{ round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, - trunc, upper, uuid + trunc, upper, uuid, window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, From 46842313dbfb1ca4512c49502e721061cd378233 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 13 Nov 2023 15:35:07 +0800 Subject: [PATCH 06/11] add sql slt test --- datafusion/expr/src/built_in_function.rs | 2 + .../physical-expr/src/string_expressions.rs | 113 +++++++++++++----- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/sql/src/expr/mod.rs | 40 ++++++- .../sqllogictest/test_files/functions.slt | 46 +++++-- 5 files changed, 160 insertions(+), 49 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e4329e0e7e30f..a00cf8b968048 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -1264,6 +1264,8 @@ impl BuiltinScalarFunction { vec![ Exact(vec![Utf8, Utf8, Int64, Int64]), Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], self.volatility(), ), diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 6eb99f76857d3..80d581aa403d2 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -558,42 +558,89 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { /// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas /// datafusion overlay('Txxxxas', 'hom', 2, 4) -> Thomas pub fn overlay(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - let len_num = as_int64_array(&args[3])?; + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .zip(len_num.iter()) - .map(|(((string, characters), start_pos), len)| { - match (string, characters, start_pos, len) { - (Some(string), Some(characters), Some(start_pos), Some(len)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = len.min(string_len as i64); - let mut res = String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "overlay was called with {other} arguments. It requires 3 or 4." + ) + } + } } #[cfg(test)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 93597e8215b1a..07697fd1ff226 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -51,10 +51,10 @@ use datafusion_expr::{ factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, - random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, + radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 1cf0fc133f040..7fa16ced39da4 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -459,7 +459,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, ), - + SQLExpr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => self.sql_overlay_to_expr( + *expr, + *overlay_what, + *overlay_from, + overlay_for, + schema, + planner_context, + ), SQLExpr::Nested(e) => { self.sql_expr_to_logical_expr(*e, schema, planner_context) } @@ -645,6 +657,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } + fn sql_overlay_to_expr( + &self, + expr: SQLExpr, + overlay_what: SQLExpr, + overlay_from: SQLExpr, + overlay_for: Option>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = BuiltinScalarFunction::OverLay; + let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let what_arg = + self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; + let from_arg = + self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; + let args = match overlay_for { + Some(for_expr) => { + let for_expr = + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; + vec![arg, what_arg, from_arg, for_expr] + } + None => vec![arg, what_arg, from_arg], + }; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + } + fn sql_agg_with_filter_to_expr( &self, expr: SQLExpr, diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 2054752cc59cd..91e13bcba372a 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -532,7 +532,7 @@ statement ok drop table simple_struct_test # create aggregate_test_100 table for functions test -statement ok +statement error DataFusion error: IO error: No such file or directory \(os error 2\) CREATE EXTERNAL TABLE aggregate_test_100 ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, @@ -554,22 +554,16 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' # sqrt_f32_vs_f64 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(sqrt(c11)) FROM aggregate_test_100 ----- -0.658440848589 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(CAST(sqrt(c11) AS double)) FROM aggregate_test_100 ----- -0.658440848589 -query R +query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100 ----- -0.658440848342 -statement ok +statement error DataFusion error: Execution error: Table 'aggregate_test_100' doesn't exist\. drop table aggregate_test_100 @@ -815,3 +809,33 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1002 OldBrand Product 2 59.98 1003 OldBrand Product 3 79.98 1004 OldBrand Product 4 99.98 + +#overlay tests +statement ok +CREATE TABLE over_test( + str TEXT, + characters TEXT, + pos INT, + len INT +) as VALUES + ('123', 'abc', 4, 5), + ('abcdefg', 'qwertyasdfg', 1, 7), + ('xyz', 'ijk', 1, 2), + ('Txxxxas', 'hom', 2, 4) +; + +query T +SELECT overlay(str placing characters from pos for len) from over_test +---- +abc +qwertyasdfg +ijkz +Thomas + +query T +SELECT overlay(str placing characters from pos) from over_test +---- +abc +qwertyasdfg +ijk +Thomxas From 3b9233f00b3ae8093f80ce581767d05f31f2a775 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 13 Nov 2023 15:42:58 +0800 Subject: [PATCH 07/11] fix modify other case issue --- datafusion/sqllogictest/test_files/functions.slt | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 91e13bcba372a..62b0ba8d910a7 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -532,7 +532,7 @@ statement ok drop table simple_struct_test # create aggregate_test_100 table for functions test -statement error DataFusion error: IO error: No such file or directory \(os error 2\) +statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, @@ -554,16 +554,22 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' # sqrt_f32_vs_f64 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(sqrt(c11)) FROM aggregate_test_100 +---- +0.658440848589 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(CAST(sqrt(c11) AS double)) FROM aggregate_test_100 +---- +0.658440848589 -query error DataFusion error: Error during planning: table 'datafusion\.public\.aggregate_test_100' not found +query R SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100 +---- +0.658440848342 -statement error DataFusion error: Execution error: Table 'aggregate_test_100' doesn't exist\. +statement ok drop table aggregate_test_100 From e148cd32b4df3c1cf0993b61c1195fad9047deb5 Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 13 Nov 2023 15:51:29 +0800 Subject: [PATCH 08/11] add test expr --- datafusion/expr/src/expr_fn.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c31209a2985be..4d17b0925c6c2 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1177,7 +1177,8 @@ mod test { test_nary_scalar_expr!(MakeArray, array, input); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); - test_nary_scalar_expr!(OverLay, overlay, string, characters, from, len); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position); } #[test] From e59ab5f212a94b5b07aff487e65aa37da9b2ba0f Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 13 Nov 2023 15:54:08 +0800 Subject: [PATCH 09/11] add annotation --- datafusion/physical-expr/src/string_expressions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 80d581aa403d2..7e954fdcfdc48 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -556,7 +556,7 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { /// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) /// Replaces a substring of string1 with string2 starting at the integer bit /// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas -/// datafusion overlay('Txxxxas', 'hom', 2, 4) -> Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead pub fn overlay(args: &[ArrayRef]) -> Result { match args.len() { 3 => { From 6ad3b6ff35b5e8c4e7e4a6e541c1eca351e5057b Mon Sep 17 00:00:00 2001 From: cli2 Date: Mon, 13 Nov 2023 16:44:32 +0800 Subject: [PATCH 10/11] add overlay function sql reference doc --- docs/source/user-guide/sql/scalar_functions.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index be05084fb2491..1750bbf3480cd 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -635,6 +635,7 @@ nullif(expression1, expression2) - [trim](#trim) - [upper](#upper) - [uuid](#uuid) +- [overlay](#overlay) ### `ascii` @@ -1120,6 +1121,21 @@ Returns UUID v4 string value which is unique per row. uuid() ``` +### `overlay` + +Returns the string which is replaced by another string from the specified position and specified count length + +``` +overlay(str PLACING substr FROM pos [FOR count]) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **substr**: the string to replace part of str. +- **pos**: the start position to replace of str. +- **count**: the count of characters to be replaced from start position of str.If not specified, will use substr length instead. + ## Binary String Functions - [decode](#decode) From 0ce524eda416870d7bd4f5afb100e70757f68123 Mon Sep 17 00:00:00 2001 From: cli2 Date: Tue, 14 Nov 2023 10:22:03 +0800 Subject: [PATCH 11/11] add sql case and format doc --- datafusion/sqllogictest/test_files/functions.slt | 14 +++++++++++++- docs/source/user-guide/sql/scalar_functions.md | 5 +++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 62b0ba8d910a7..8f42304384802 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -827,7 +827,11 @@ CREATE TABLE over_test( ('123', 'abc', 4, 5), ('abcdefg', 'qwertyasdfg', 1, 7), ('xyz', 'ijk', 1, 2), - ('Txxxxas', 'hom', 2, 4) + ('Txxxxas', 'hom', 2, 4), + (NULL, 'hom', 2, 4), + ('Txxxxas', 'hom', NULL, 4), + ('Txxxxas', 'hom', 2, NULL), + ('Txxxxas', NULL, 2, 4) ; query T @@ -837,6 +841,10 @@ abc qwertyasdfg ijkz Thomas +NULL +NULL +NULL +NULL query T SELECT overlay(str placing characters from pos) from over_test @@ -845,3 +853,7 @@ abc qwertyasdfg ijk Thomxas +NULL +NULL +Thomxas +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 19429fa8b2b0b..099c90312227f 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1123,7 +1123,8 @@ uuid() ### `overlay` -Returns the string which is replaced by another string from the specified position and specified count length +Returns the string which is replaced by another string from the specified position and specified count length. +For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` ``` overlay(str PLACING substr FROM pos [FOR count]) @@ -1134,7 +1135,7 @@ overlay(str PLACING substr FROM pos [FOR count]) - **str**: String expression to operate on. - **substr**: the string to replace part of str. - **pos**: the start position to replace of str. -- **count**: the count of characters to be replaced from start position of str.If not specified, will use substr length instead. +- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. ## Binary String Functions