Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,7 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionSubReplace<SubReplaceFourImpl>>();
factory.register_function<FunctionOverlay>();
factory.register_function<FunctionStrcmp>();
factory.register_function<FunctionNgramSearch>();

factory.register_alias(FunctionLeft::name, "strleft");
factory.register_alias(FunctionRight::name, "strright");
Expand Down
127 changes: 127 additions & 0 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "vec/columns/column.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_vector.h"
#include "vec/common/hash_table/phmap_fwd_decl.h"
#include "vec/common/int_exp.h"
#include "vec/common/memcmp_small.h"
#include "vec/common/memcpy_small.h"
Expand Down Expand Up @@ -3674,4 +3675,130 @@ class FunctionOverlay : public IFunction {
}
}
};

class FunctionNgramSearch : public IFunction {
public:
static constexpr auto name = "ngram_search";
static FunctionPtr create() { return std::make_shared<FunctionNgramSearch>(); }
String get_name() const override { return name; }
size_t get_number_of_arguments() const override { return 3; }
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return std::make_shared<DataTypeFloat64>();
}

// ngram_search(text,pattern,gram_num)
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
CHECK_EQ(arguments.size(), 3);
auto col_res = ColumnFloat64::create();
bool col_const[3];
ColumnPtr argument_columns[3];
for (int i = 0; i < 3; ++i) {
std::tie(argument_columns[i], col_const[i]) =
unpack_if_const(block.get_by_position(arguments[i]).column);
}
// There is no need to check if the 2-th,3-th parameters are const here because fe has already checked them.
auto pattern = assert_cast<const ColumnString*>(argument_columns[1].get())->get_data_at(0);
auto gram_num = assert_cast<const ColumnInt32*>(argument_columns[2].get())->get_element(0);
const auto* text_col = assert_cast<const ColumnString*>(argument_columns[0].get());

if (col_const[0]) {
_execute_impl<true>(text_col, pattern, gram_num, *col_res, input_rows_count);
} else {
_execute_impl<false>(text_col, pattern, gram_num, *col_res, input_rows_count);
}

block.replace_by_position(result, std::move(col_res));
return Status::OK();
}

private:
using NgramMap = phmap::flat_hash_map<uint32_t, uint8_t>;
// In the map, the key is the CRC32 hash result of a substring in the string,
// and the value indicates whether this hash is found in the text or pattern.
constexpr static auto not_found = 0b00;
constexpr static auto found_in_pattern = 0b01;
constexpr static auto found_in_text = 0b10;
constexpr static auto found_in_pattern_and_text = 0b11;

uint32_t sub_str_hash(const char* data, int32_t length) const {
constexpr static uint32_t seed = 0;
return HashUtil::crc_hash(data, length, seed);
}

template <bool column_const>
void _execute_impl(const ColumnString* text_col, StringRef& pattern, int gram_num,
ColumnFloat64& res, size_t size) const {
auto& res_data = res.get_data();
res_data.resize_fill(size, 0);
// If the length of the pattern is less than gram_num, return 0.
if (pattern.size < gram_num) {
return;
}

// Build a map by pattern string, which will be used repeatedly in the following loop.
NgramMap pattern_map;
int pattern_count = get_pattern_set(pattern_map, pattern, gram_num);
// Each time a loop is executed, the map will be modified, so it needs to be restored afterward.
std::vector<uint32_t> restore_map;

for (int i = 0; i < size; i++) {
auto text = text_col->get_data_at(index_check_const<column_const>(i));
if (text.size < gram_num) {
// If the length of the text is less than gram_num, return 0.
continue;
}
restore_map.reserve(text.size);
auto [text_count, intersection_count] =
get_text_set(text, gram_num, pattern_map, restore_map);

// 2 * |Intersection| / (|text substr set| + |pattern substr set|)
res_data[i] = 2.0 * intersection_count / (text_count + pattern_count);
}
}

size_t get_pattern_set(NgramMap& pattern_map, StringRef& pattern, int gram_num) const {
size_t pattern_count = 0;
for (int i = 0; i + gram_num <= pattern.size; i++) {
uint32_t cur_hash = sub_str_hash(pattern.data + i, gram_num);
if (!pattern_map.contains(cur_hash)) {
pattern_map[cur_hash] = found_in_pattern;
pattern_count++;
}
}
return pattern_count;
}

pair<size_t, size_t> get_text_set(StringRef& text, int gram_num, NgramMap& pattern_map,
std::vector<uint32_t>& restore_map) const {
restore_map.clear();
//intersection_count indicates a substring both in pattern and text.
size_t text_count = 0, intersection_count = 0;
for (int i = 0; i + gram_num <= text.size; i++) {
uint32_t cur_hash = sub_str_hash(text.data + i, gram_num);
auto& val = pattern_map[cur_hash];
if (val == not_found) {
val ^= found_in_text;
DCHECK(val == found_in_text);
// only found in text
text_count++;
restore_map.push_back(cur_hash);
} else if (val == found_in_pattern) {
val ^= found_in_text;
DCHECK(val == found_in_pattern_and_text);
// found in text and pattern
text_count++;
intersection_count++;
restore_map.push_back(cur_hash);
}
}
// Restore the pattern_map.
for (auto& restore_hash : restore_map) {
pattern_map[restore_hash] ^= found_in_text;
}

return {text_count, intersection_count};
}
};

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.MurmurHash332;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MurmurHash364;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Negative;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NgramSearch;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NotNullOrEmpty;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Now;
Expand Down Expand Up @@ -779,6 +780,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Negative.class, "negative"),
scalar(NonNullable.class, "non_nullable"),
scalar(NotNullOrEmpty.class, "not_null_or_empty"),
scalar(NgramSearch.class, "ngram_search"),
scalar(Now.class, "now", "current_timestamp", "localtime", "localtimestamp"),
scalar(Nullable.class, "nullable"),
scalar(NullIf.class, "nullif"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'NgramSearch'.
*/
public class NgramSearch extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE,
IntegerType.INSTANCE));

/**
* constructor with 3 argument.
*/
public NgramSearch(Expression arg0, Expression arg1, Expression arg2) {
super("ngram_search", arg0, arg1, arg2);
if (!(arg1.isConstant())) {
throw new AnalysisException(
"ngram_search(text,pattern,gram_num): pattern support const value only.");
}
if (!(arg2.isConstant())) {
throw new AnalysisException(
"ngram_search(text,pattern,gram_num): gram_num support const value only.");
}
}

/**
* withChildren.
*/
@Override
public NgramSearch withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 3);
return new NgramSearch(children.get(0), children.get(1), children.get(2));
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitNgramSearch(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.MurmurHash332;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MurmurHash364;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Negative;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NgramSearch;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NotNullOrEmpty;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Now;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf;
Expand Down Expand Up @@ -1608,6 +1609,10 @@ default R visitNegative(Negative negative, C context) {
return visitScalarFunction(negative, context);
}

default R visitNgramSearch(NgramSearch ngramSearch, C context) {
return visitScalarFunction(ngramSearch, context);
}

default R visitNotNullOrEmpty(NotNullOrEmpty notNullOrEmpty, C context) {
return visitScalarFunction(notNullOrEmpty, context);
}
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -361,4 +361,33 @@ suite("test_string_function", "arrow_flight_sql") {
qt_strcmp1 """ select strcmp('a', 'abc'); """
qt_strcmp2 """ select strcmp('abc', 'abc'); """
qt_strcmp3 """ select strcmp('abcd', 'abc'); """

sql "drop table if exists test_function_ngram_search;";
sql """ create table test_function_ngram_search (
k1 int not null,
s string null
) distributed by hash (k1) buckets 1
properties ("replication_num"="1");
"""

sql """ insert into test_function_ngram_search values(1,"fffhhhkkkk"),(2,"abc1313131"),(3,'1313131') ,(4,'abc') , (5,null)"""

qt_ngram_search1 """ select k1, ngram_search(s,'abc1313131',3) as x , s from test_function_ngram_search order by x ;"""

qt_ngram_search2 """select ngram_search('abc','abc1313131',3); """
qt_ngram_search3 """select ngram_search('abc1313131','abc1313131',3); """
qt_ngram_search3 """select ngram_search('1313131','abc1313131',3); """


sql "drop table if exists test_function_ngram_search;";
sql """ create table test_function_ngram_search (
k1 int not null,
s string not null
) distributed by hash (k1) buckets 1
properties ("replication_num"="1");
"""

sql """ insert into test_function_ngram_search values(1,"fffhhhkkkk"),(2,"abc1313131"),(3,'1313131') ,(4,'abc') """

qt_ngram_search1_not_null """ select k1, ngram_search(s,'abc1313131',3) as x , s from test_function_ngram_search order by x ;"""
}