diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 0d2c1f2e3cb7e..77c64128e1562 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -292,6 +292,8 @@ pub enum BuiltinScalarFunction { RegexpMatch, /// arrow_typeof ArrowTypeof, + /// overlay + OverLay, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -455,6 +457,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::OverLay => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -812,6 +815,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::OverLay => { + utf8_to_str_type(&input_expr_types[0], "overlay") + } + BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1258,7 +1265,15 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()), - + BuiltinScalarFunction::OverLay => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1517,6 +1532,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], + BuiltinScalarFunction::OverLay => &["overlay"], // struct functions BuiltinScalarFunction::Struct => &["struct"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0d920beb416f1..91674cc092e61 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -838,6 +838,11 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); +nary_scalar_expr!( + OverLay, + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring" +); // date functions scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); @@ -1174,6 +1179,8 @@ mod test { test_nary_scalar_expr!(MakeArray, array, input); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 80c0eaf054fd0..7f8921e86c38e 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -829,6 +829,17 @@ pub fn create_physical_fun( "{input_data_type}" ))))) }), + BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function overlay", + ))), + }), }) } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index e6a3d5c331a54..7e954fdcfdc48 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -553,11 +553,102 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead +pub fn overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "overlay was called with {other} arguments. It requires 3 or 4." + ) + } + } +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow_array::Int64Array; use super::*; @@ -599,4 +690,21 @@ mod tests { Ok(()) } + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 5d7c570bc1737..d85678a76bf1a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -636,6 +636,7 @@ enum ScalarFunction { ToTimestampNanos = 118; ArrayIntersect = 119; ArrayUnion = 120; + OverLay = 121; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 12fa73205d49f..64db9137d64f7 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20935,6 +20935,7 @@ impl serde::Serialize for ScalarFunction { Self::ToTimestampNanos => "ToTimestampNanos", Self::ArrayIntersect => "ArrayIntersect", Self::ArrayUnion => "ArrayUnion", + Self::OverLay => "OverLay", }; serializer.serialize_str(variant) } @@ -21067,6 +21068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ToTimestampNanos", "ArrayIntersect", "ArrayUnion", + "OverLay", ]; struct GeneratedVisitor; @@ -21228,6 +21230,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), + "OverLay" => Ok(ScalarFunction::OverLay), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 23be5d9088666..131ca11993c1f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2567,6 +2567,7 @@ pub enum ScalarFunction { ToTimestampNanos = 118, ArrayIntersect = 119, ArrayUnion = 120, + OverLay = 121, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2696,6 +2697,7 @@ impl ScalarFunction { ScalarFunction::ToTimestampNanos => "ToTimestampNanos", ScalarFunction::ArrayIntersect => "ArrayIntersect", ScalarFunction::ArrayUnion => "ArrayUnion", + ScalarFunction::OverLay => "OverLay", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2822,6 +2824,7 @@ impl ScalarFunction { "ToTimestampNanos" => Some(Self::ToTimestampNanos), "ArrayIntersect" => Some(Self::ArrayIntersect), "ArrayUnion" => Some(Self::ArrayUnion), + "OverLay" => Some(Self::OverLay), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 0ecbe05e79032..9ca7bb0e893a3 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -52,10 +52,10 @@ use datafusion_expr::{ factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, - random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, + radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, @@ -546,6 +546,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Isnan => Self::Isnan, ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::OverLay => Self::OverLay, } } } @@ -1680,6 +1681,12 @@ pub fn parse_expr( parse_expr(&args[1], registry)?, parse_expr(&args[2], registry)?, )), + ScalarFunction::OverLay => Ok(overlay( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::StructFun => { Ok(struct_fun(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 4c81ab954a71d..974d6c5aaba8f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1553,6 +1553,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Isnan => Self::Isnan, BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::OverLay => Self::OverLay, }; Ok(scalar_function) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 1cf0fc133f040..7fa16ced39da4 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -459,7 +459,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, ), - + SQLExpr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => self.sql_overlay_to_expr( + *expr, + *overlay_what, + *overlay_from, + overlay_for, + schema, + planner_context, + ), SQLExpr::Nested(e) => { self.sql_expr_to_logical_expr(*e, schema, planner_context) } @@ -645,6 +657,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } + fn sql_overlay_to_expr( + &self, + expr: SQLExpr, + overlay_what: SQLExpr, + overlay_from: SQLExpr, + overlay_for: Option>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = BuiltinScalarFunction::OverLay; + let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let what_arg = + self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; + let from_arg = + self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; + let args = match overlay_for { + Some(for_expr) => { + let for_expr = + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; + vec![arg, what_arg, from_arg, for_expr] + } + None => vec![arg, what_arg, from_arg], + }; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + } + fn sql_agg_with_filter_to_expr( &self, expr: SQLExpr, diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 2054752cc59cd..8f42304384802 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -815,3 +815,45 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1002 OldBrand Product 2 59.98 1003 OldBrand Product 3 79.98 1004 OldBrand Product 4 99.98 + +#overlay tests +statement ok +CREATE TABLE over_test( + str TEXT, + characters TEXT, + pos INT, + len INT +) as VALUES + ('123', 'abc', 4, 5), + ('abcdefg', 'qwertyasdfg', 1, 7), + ('xyz', 'ijk', 1, 2), + ('Txxxxas', 'hom', 2, 4), + (NULL, 'hom', 2, 4), + ('Txxxxas', 'hom', NULL, 4), + ('Txxxxas', 'hom', 2, NULL), + ('Txxxxas', NULL, 2, 4) +; + +query T +SELECT overlay(str placing characters from pos for len) from over_test +---- +abc +qwertyasdfg +ijkz +Thomas +NULL +NULL +NULL +NULL + +query T +SELECT overlay(str placing characters from pos) from over_test +---- +abc +qwertyasdfg +ijk +Thomxas +NULL +NULL +Thomxas +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 2959e8202437a..099c90312227f 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -635,6 +635,7 @@ nullif(expression1, expression2) - [trim](#trim) - [upper](#upper) - [uuid](#uuid) +- [overlay](#overlay) ### `ascii` @@ -1120,6 +1121,22 @@ Returns UUID v4 string value which is unique per row. uuid() ``` +### `overlay` + +Returns the string which is replaced by another string from the specified position and specified count length. +For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` + +``` +overlay(str PLACING substr FROM pos [FOR count]) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **substr**: the string to replace part of str. +- **pos**: the start position to replace of str. +- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. + ## Binary String Functions - [decode](#decode)