Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 7 additions & 37 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,12 @@ pub enum BuiltinScalarFunction {
Lpad,
/// random
Random,
/// repeat
Repeat,
/// replace
Replace,
/// reverse
Reverse,
/// right
Right,
/// rpad
Rpad,
/// split_part
SplitPart,
/// strpos
Strpos,
/// substr
Expand Down Expand Up @@ -238,12 +232,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Radians => Volatility::Immutable,
BuiltinScalarFunction::Repeat => Volatility::Immutable,
BuiltinScalarFunction::Replace => Volatility::Immutable,
BuiltinScalarFunction::Reverse => Volatility::Immutable,
BuiltinScalarFunction::Right => Volatility::Immutable,
BuiltinScalarFunction::Rpad => Volatility::Immutable,
BuiltinScalarFunction::SplitPart => Volatility::Immutable,
BuiltinScalarFunction::Strpos => Volatility::Immutable,
BuiltinScalarFunction::Substr => Volatility::Immutable,
BuiltinScalarFunction::Translate => Volatility::Immutable,
Expand Down Expand Up @@ -293,22 +284,13 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
BuiltinScalarFunction::Pi => Ok(Float64),
BuiltinScalarFunction::Random => Ok(Float64),
BuiltinScalarFunction::Repeat => {
utf8_to_str_type(&input_expr_types[0], "repeat")
}
BuiltinScalarFunction::Replace => {
utf8_to_str_type(&input_expr_types[0], "replace")
}
BuiltinScalarFunction::Reverse => {
utf8_to_str_type(&input_expr_types[0], "reverse")
}
BuiltinScalarFunction::Right => {
utf8_to_str_type(&input_expr_types[0], "right")
}
BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"),
BuiltinScalarFunction::SplitPart => {
utf8_to_str_type(&input_expr_types[0], "split_part")
}
BuiltinScalarFunction::EndsWith => Ok(Boolean),
BuiltinScalarFunction::Strpos => {
utf8_to_int_type(&input_expr_types[0], "strpos/instr/position")
Expand Down Expand Up @@ -417,21 +399,12 @@ impl BuiltinScalarFunction {
self.volatility(),
)
}
BuiltinScalarFunction::Left
| BuiltinScalarFunction::Repeat
| BuiltinScalarFunction::Right => Signature::one_of(
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
self.volatility(),
),
BuiltinScalarFunction::SplitPart => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, Utf8, Int64]),
Exact(vec![Utf8, LargeUtf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
self.volatility(),
),
BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => {
Signature::one_of(
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
self.volatility(),
)
}

BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => {
Signature::one_of(
Expand Down Expand Up @@ -467,7 +440,7 @@ impl BuiltinScalarFunction {
self.volatility(),
),

BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => {
BuiltinScalarFunction::Translate => {
Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility())
}
BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()),
Expand Down Expand Up @@ -637,12 +610,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => &["initcap"],
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lpad => &["lpad"],
BuiltinScalarFunction::Repeat => &["repeat"],
BuiltinScalarFunction::Replace => &["replace"],
BuiltinScalarFunction::Reverse => &["reverse"],
BuiltinScalarFunction::Right => &["right"],
BuiltinScalarFunction::Rpad => &["rpad"],
BuiltinScalarFunction::SplitPart => &["split_part"],
BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"],
BuiltinScalarFunction::Substr => &["substr"],
BuiltinScalarFunction::Translate => &["translate"],
Expand Down
6 changes: 0 additions & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,11 +598,8 @@ scalar_expr!(
);
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`");
scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times");
scalar_expr!(Reverse, reverse, string, "reverses the `string`");
scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`");
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
scalar_expr!(Substr, substr, string position, "substring from the `position` to the end");
Expand Down Expand Up @@ -1056,13 +1053,10 @@ mod test {
test_scalar_expr!(Left, left, string, count);
test_nary_scalar_expr!(Lpad, lpad, string, count);
test_nary_scalar_expr!(Lpad, lpad, string, count, characters);
test_scalar_expr!(Replace, replace, string, from, to);
test_scalar_expr!(Repeat, repeat, string, count);
test_scalar_expr!(Reverse, reverse, string);
test_scalar_expr!(Right, right, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count, characters);
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
test_scalar_expr!(Substr, substr, string, position);
Expand Down
24 changes: 24 additions & 0 deletions datafusion/functions/src/string/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ mod lower;
mod ltrim;
mod octet_length;
mod overlay;
mod repeat;
mod replace;
mod rtrim;
mod split_part;
mod starts_with;
mod to_hex;
mod upper;
Expand All @@ -43,8 +46,11 @@ make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim);
make_udf_function!(lower::LowerFunc, LOWER, lower);
make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length);
make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay);
make_udf_function!(repeat::RepeatFunc, REPEAT, repeat);
make_udf_function!(replace::ReplaceFunc, REPLACE, replace);
make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim);
make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with);
make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part);
make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex);
make_udf_function!(upper::UpperFunc, UPPER, upper);
make_udf_function!(uuid::UuidFunc, UUID, uuid);
Expand Down Expand Up @@ -87,11 +93,26 @@ pub mod expr_fn {
super::overlay().call(args)
}

#[doc = "Repeats the `string` to `n` times"]
pub fn repeat(string: Expr, n: Expr) -> Expr {
    // Argument order is (string, count), as expected by the `repeat` UDF.
    let call_args = vec![string, n];
    super::repeat().call(call_args)
}

#[doc = "Replaces all occurrences of `from` with `to` in the `string`"]
pub fn replace(string: Expr, from: Expr, to: Expr) -> Expr {
    // Argument order is (string, from, to).
    let call_args = vec![string, from, to];
    super::replace().call(call_args)
}

#[doc = "Removes all characters, spaces by default, from the end of a string"]
pub fn rtrim(args: Vec<Expr>) -> Expr {
    // The caller-supplied argument list is forwarded unchanged.
    let udf = super::rtrim();
    udf.call(args)
}

#[doc = "Splits a string based on a delimiter and picks out the desired field based on the index."]
pub fn split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr {
    // Argument order is (string, delimiter, index).
    let call_args = vec![string, delimiter, index];
    super::split_part().call(call_args)
}

#[doc = "Returns true if string starts with prefix."]
pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr {
super::starts_with().call(vec![arg1, arg2])
Expand Down Expand Up @@ -128,7 +149,10 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
ltrim(),
octet_length(),
overlay(),
repeat(),
replace(),
rtrim(),
split_part(),
starts_with(),
to_hex(),
upper(),
Expand Down
144 changes: 144 additions & 0 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

use crate::string::common::*;

/// Scalar UDF implementation backing the `repeat` string function.
#[derive(Debug)]
pub(super) struct RepeatFunc {
/// Accepted argument types and volatility; populated by `RepeatFunc::new`.
signature: Signature,
}

impl RepeatFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for RepeatFunc {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The function name: `repeat`.
    fn name(&self) -> &str {
        "repeat"
    }

    /// The signature built in [`RepeatFunc::new`].
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Derives the return type from the type of the first (string) argument.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let string_type = &arg_types[0];
        utf8_to_str_type(string_type, "repeat")
    }

    /// Dispatches to the `repeat` kernel whose offset width matches the first
    /// argument: `i32` offsets for `Utf8`, `i64` offsets for `LargeUtf8`.
    /// Any other first-argument type is an execution error.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let string_type = args[0].data_type();
        match string_type {
            DataType::Utf8 => {
                let kernel = make_scalar_function(repeat::<i32>, vec![]);
                kernel(args)
            }
            DataType::LargeUtf8 => {
                let kernel = make_scalar_function(repeat::<i64>, vec![]);
                kernel(args)
            }
            other => exec_err!("Unsupported data type {other:?} for function repeat"),
        }
    }
}

/// Repeats each string the number of times given by the matching count.
///
/// `args[0]` is downcast to a string array with offset type `T` and `args[1]`
/// to an `Int64` array; an error is returned if either downcast fails. Rows are
/// paired positionally.
///
/// * A row is null when either the string or the count is null.
/// * A zero or negative count yields an empty string.
///
/// repeat('Pg', 4) = 'PgPgPgPg'
/// repeat('Pg', -1) = ''
fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    let string_array = as_generic_string_array::<T>(&args[0])?;
    let number_array = as_int64_array(&args[1])?;

    let result = string_array
        .iter()
        .zip(number_array.iter())
        .map(|(string, number)| match (string, number) {
            (Some(string), Some(number)) if number > 0 => {
                Some(string.repeat(number as usize))
            }
            // Casting a negative `i64` with `as usize` wraps to a huge value,
            // which makes `str::repeat` panic on capacity overflow. Treat
            // non-positive counts as "repeat zero times" instead.
            (Some(_), Some(_)) => Some(String::new()),
            _ => None,
        })
        .collect::<GenericStringArray<T>>();

    Ok(Arc::new(result) as ArrayRef)
}

// Unit tests for `RepeatFunc`, driven through the shared `test_function!` macro.
#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
use arrow::datatypes::DataType::Utf8;

use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

use crate::string::common::test::test_function;
use crate::string::repeat::RepeatFunc;

// NOTE(review): the trailing `&str, Utf8, StringArray` macro arguments are
// presumably the expected value type, return data type and result array type;
// confirm against the `test_function!` definition.
#[test]
fn test_functions() -> Result<()> {
// Non-null string with a positive count: "Pg" repeated 4 times.
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
],
Ok(Some("PgPgPgPg")),
&str,
Utf8,
StringArray
);

// Null string input: the result is expected to be null.
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8(None)),
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
],
Ok(None),
&str,
Utf8,
StringArray
);
// Null count input: the result is expected to be null.
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
ColumnarValue::Scalar(ScalarValue::Int64(None)),
],
Ok(None),
&str,
Utf8,
StringArray
);

Ok(())
}
}
Loading