Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 157 additions & 3 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <math.h>
#include <re2/stringpiece.h>

#include <bitset>
#include <cstddef>
#include <string_view>

Expand Down Expand Up @@ -507,6 +508,15 @@ struct NameLTrim {
struct NameRTrim {
static constexpr auto name = "rtrim";
};
// Name tag for the function that trims a character set from both ends of a string.
struct NameTrimIn {
    static constexpr const char* name = "trim_in";
};
// Name tag for the function that trims a character set from the left end of a string.
struct NameLTrimIn {
    static constexpr const char* name = "ltrim_in";
};
// Name tag for the function that trims a character set from the right end of a string.
struct NameRTrimIn {
    static constexpr const char* name = "rtrim_in";
};
template <bool is_ltrim, bool is_rtrim, bool trim_single>
struct TrimUtil {
static Status vector(const ColumnString::Chars& str_data,
Expand Down Expand Up @@ -534,6 +544,135 @@ struct TrimUtil {
return Status::OK();
}
};
template <bool is_ltrim, bool is_rtrim, bool trim_single>
struct TrimInUtil {
static Status vector(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets, const StringRef& remove_str,
ColumnString::Chars& res_data, ColumnString::Offsets& res_offsets) {
const size_t offset_size = str_offsets.size();
res_offsets.resize(offset_size);
res_data.reserve(str_data.size());
bool all_ascii = simd::VStringFunctions::is_ascii(remove_str) &&
simd::VStringFunctions::is_ascii(StringRef(
reinterpret_cast<const char*>(str_data.data()), str_data.size()));

if (all_ascii) {
return impl_vectors_ascii(str_data, str_offsets, remove_str, res_data, res_offsets);
} else {
return impl_vectors_utf8(str_data, str_offsets, remove_str, res_data, res_offsets);
}
}

private:
static Status impl_vectors_ascii(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets,
const StringRef& remove_str, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets) {
const size_t offset_size = str_offsets.size();
std::bitset<128> char_lookup;
const char* remove_begin = remove_str.data;
const char* remove_end = remove_str.data + remove_str.size;

while (remove_begin < remove_end) {
char_lookup.set(static_cast<unsigned char>(*remove_begin));
remove_begin += 1;
}

for (size_t i = 0; i < offset_size; ++i) {
const char* str_begin =
reinterpret_cast<const char*>(str_data.data() + str_offsets[i - 1]);
const char* str_end = reinterpret_cast<const char*>(str_data.data() + str_offsets[i]);
const char* left_trim_pos = str_begin;
const char* right_trim_pos = str_end;

if constexpr (is_ltrim) {
while (left_trim_pos < str_end) {
if (!char_lookup.test(static_cast<unsigned char>(*left_trim_pos))) {
break;
}
++left_trim_pos;
}
}

if constexpr (is_rtrim) {
while (right_trim_pos > left_trim_pos) {
--right_trim_pos;
if (!char_lookup.test(static_cast<unsigned char>(*right_trim_pos))) {
++right_trim_pos;
break;
}
}
}

res_data.insert_assume_reserved(left_trim_pos, right_trim_pos);
res_offsets[i] = res_data.size();
}

return Status::OK();
}

static Status impl_vectors_utf8(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets,
const StringRef& remove_str, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets) {
const size_t offset_size = str_offsets.size();
res_offsets.resize(offset_size);
res_data.reserve(str_data.size());

std::unordered_set<std::string_view> char_lookup;
const char* remove_begin = remove_str.data;
const char* remove_end = remove_str.data + remove_str.size;

while (remove_begin < remove_end) {
size_t byte_len, char_len;
std::tie(byte_len, char_len) = simd::VStringFunctions::iterate_utf8_with_limit_length(
remove_begin, remove_end, 1);
char_lookup.insert(std::string_view(remove_begin, byte_len));
remove_begin += byte_len;
}

for (size_t i = 0; i < offset_size; ++i) {
const char* str_begin =
reinterpret_cast<const char*>(str_data.data() + str_offsets[i - 1]);
const char* str_end = reinterpret_cast<const char*>(str_data.data() + str_offsets[i]);
const char* left_trim_pos = str_begin;
const char* right_trim_pos = str_end;

if constexpr (is_ltrim) {
while (left_trim_pos < str_end) {
size_t byte_len, char_len;
std::tie(byte_len, char_len) =
simd::VStringFunctions::iterate_utf8_with_limit_length(left_trim_pos,
str_end, 1);
if (char_lookup.find(std::string_view(left_trim_pos, byte_len)) ==
char_lookup.end()) {
break;
}
left_trim_pos += byte_len;
}
}

if constexpr (is_rtrim) {
while (right_trim_pos > left_trim_pos) {
const char* prev_char_pos = right_trim_pos;
do {
--prev_char_pos;
} while ((*prev_char_pos & 0xC0) == 0x80);
size_t byte_len = right_trim_pos - prev_char_pos;
if (char_lookup.find(std::string_view(prev_char_pos, byte_len)) ==
char_lookup.end()) {
break;
}
right_trim_pos = prev_char_pos;
}
}

res_data.insert_assume_reserved(left_trim_pos, right_trim_pos);
res_offsets[i] = res_data.size();
}
return Status::OK();
}
};
// This is an implementation of a parameter for the Trim function.
template <bool is_ltrim, bool is_rtrim, typename Name>
struct Trim1Impl {
Expand Down Expand Up @@ -582,14 +721,23 @@ struct Trim2Impl {
const auto* remove_str_raw = col_right->get_chars().data();
const ColumnString::Offset remove_str_size = col_right->get_offsets()[0];
const StringRef remove_str(remove_str_raw, remove_str_size);

if (remove_str.size == 1) {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, true>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
} else {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
if constexpr (std::is_same<Name, NameTrimIn>::value ||
std::is_same<Name, NameLTrimIn>::value ||
std::is_same<Name, NameRTrimIn>::value) {
RETURN_IF_ERROR((TrimInUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(), col_res->get_offsets())));
} else {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(), col_res->get_offsets())));
}
}
block.replace_by_position(result, std::move(col_res));
} else {
Expand Down Expand Up @@ -1156,6 +1304,12 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionTrim<Trim2Impl<true, true, NameTrim>>>();
factory.register_function<FunctionTrim<Trim2Impl<true, false, NameLTrim>>>();
factory.register_function<FunctionTrim<Trim2Impl<false, true, NameRTrim>>>();
factory.register_function<FunctionTrim<Trim1Impl<true, true, NameTrimIn>>>();
factory.register_function<FunctionTrim<Trim1Impl<true, false, NameLTrimIn>>>();
factory.register_function<FunctionTrim<Trim1Impl<false, true, NameRTrimIn>>>();
factory.register_function<FunctionTrim<Trim2Impl<true, true, NameTrimIn>>>();
factory.register_function<FunctionTrim<Trim2Impl<true, false, NameLTrimIn>>>();
factory.register_function<FunctionTrim<Trim2Impl<false, true, NameRTrimIn>>>();
factory.register_function<FunctionConvertTo>();
factory.register_function<FunctionSubstring<Substr3Impl>>();
factory.register_function<FunctionSubstring<Substr2Impl>>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lower;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lpad;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Ltrim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LtrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeDate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue;
Expand Down Expand Up @@ -346,6 +347,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.RoundBankers;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Rpad;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Rtrim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RtrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecToTime;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Second;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondCeil;
Expand Down Expand Up @@ -424,6 +426,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Tokenize;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Translate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Trim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.TrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Unhex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UnixTimestamp;
Expand Down Expand Up @@ -734,6 +737,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Lower.class, "lcase", "lower"),
scalar(Lpad.class, "lpad"),
scalar(Ltrim.class, "ltrim"),
scalar(LtrimIn.class, "ltrim_in"),
scalar(MakeDate.class, "makedate"),
scalar(MapContainsKey.class, "map_contains_key"),
scalar(MapContainsValue.class, "map_contains_value"),
Expand Down Expand Up @@ -808,6 +812,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(RoundBankers.class, "round_bankers"),
scalar(Rpad.class, "rpad"),
scalar(Rtrim.class, "rtrim"),
scalar(RtrimIn.class, "rtrim_in"),
scalar(Second.class, "second"),
scalar(SecondCeil.class, "second_ceil"),
scalar(SecondFloor.class, "second_floor"),
Expand Down Expand Up @@ -892,6 +897,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ToQuantileState.class, "to_quantile_state"),
scalar(Translate.class, "translate"),
scalar(Trim.class, "trim"),
scalar(TrimIn.class, "trim_in"),
scalar(Truncate.class, "truncate"),
scalar(Unhex.class, "unhex"),
scalar(UnixTimestamp.class, "unix_timestamp"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ private static String trimImpl(String first, String second, boolean left, boolea
return result;
}

/**
 * Strips from {@code first} every leading and/or trailing character that occurs in
 * {@code second}, treating {@code second} as a set of characters.
 *
 * <p>Matching is done per Unicode code point rather than per UTF-16 {@code char}, so a
 * supplementary character (surrogate pair) is only removed when that exact character is in
 * the set, and a pair is never split in half.
 *
 * @param first the string to trim
 * @param second the set of characters to remove
 * @param left whether to strip matching characters from the start
 * @param right whether to strip matching characters from the end
 * @return the trimmed string
 */
private static String trimInImpl(String first, String second, boolean left, boolean right) {
    int start = 0;
    int end = first.length();

    if (left) {
        while (start < end) {
            int codePoint = first.codePointAt(start);
            // String.indexOf(int) searches by code point, including supplementary ones.
            if (second.indexOf(codePoint) == -1) {
                break;
            }
            start += Character.charCount(codePoint);
        }
    }
    if (right) {
        // Stop at 'start' so a fully trimmed string yields the empty string.
        while (end > start) {
            int codePoint = first.codePointBefore(end);
            if (second.indexOf(codePoint) == -1) {
                break;
            }
            end -= Character.charCount(codePoint);
        }
    }

    return first.substring(start, end);
}

/**
* Executable arithmetic functions Trim
*/
Expand Down Expand Up @@ -200,6 +221,54 @@ public static Expression rtrimVarcharVarchar(StringLikeLiteral first, StringLike
return castStringLikeLiteral(first, trimImpl(first.getValue(), second.getValue(), false, true));
}

/**
 * Evaluates {@code trim_in(first)} on a literal: space characters are stripped from both
 * ends, and the result is passed through {@code castStringLikeLiteral} together with
 * {@code first}.
 */
@ExecFunction(name = "trim_in")
public static Expression trimInVarchar(StringLikeLiteral first) {
    String trimmed = trimInImpl(first.getValue(), " ", true, true);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates {@code trim_in(first, second)} on literals: every character of {@code second}
 * is stripped from both ends of {@code first}, and the result is passed through
 * {@code castStringLikeLiteral} together with {@code first}.
 */
@ExecFunction(name = "trim_in")
public static Expression trimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
    String source = first.getValue();
    String removeSet = second.getValue();
    String trimmed = trimInImpl(source, removeSet, true, true);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates {@code ltrim_in(first)} on a literal: space characters are stripped from the
 * start only, and the result is passed through {@code castStringLikeLiteral} together with
 * {@code first}.
 */
@ExecFunction(name = "ltrim_in")
public static Expression ltrimInVarchar(StringLikeLiteral first) {
    String trimmed = trimInImpl(first.getValue(), " ", true, false);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates {@code ltrim_in(first, second)} on literals: every character of {@code second}
 * is stripped from the start of {@code first} only, and the result is passed through
 * {@code castStringLikeLiteral} together with {@code first}.
 */
@ExecFunction(name = "ltrim_in")
public static Expression ltrimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
    String source = first.getValue();
    String removeSet = second.getValue();
    String trimmed = trimInImpl(source, removeSet, true, false);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates {@code rtrim_in(first)} on a literal: space characters are stripped from the
 * end only, and the result is passed through {@code castStringLikeLiteral} together with
 * {@code first}.
 */
@ExecFunction(name = "rtrim_in")
public static Expression rtrimInVarchar(StringLikeLiteral first) {
    String trimmed = trimInImpl(first.getValue(), " ", false, true);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates {@code rtrim_in(first, second)} on literals: every character of {@code second}
 * is stripped from the end of {@code first} only, and the result is passed through
 * {@code castStringLikeLiteral} together with {@code first}.
 */
@ExecFunction(name = "rtrim_in")
public static Expression rtrimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
    String source = first.getValue();
    String removeSet = second.getValue();
    String trimmed = trimInImpl(source, removeSet, false, true);
    return castStringLikeLiteral(first, trimmed);
}

/**
* Executable arithmetic functions Replace
*/
Expand Down
Loading