Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions be/src/olap/rowset/segment_v2/inverted_index_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ class InvertedIndexQueryParamFactory {
M(PrimitiveType::TYPE_STRING)
M(PrimitiveType::TYPE_DATEV2)
M(PrimitiveType::TYPE_DATETIMEV2)
M(PrimitiveType::TYPE_IPV4)
M(PrimitiveType::TYPE_IPV6)
#undef M
default:
return Status::NotSupported("Unsupported primitive type {} for inverted index reader",
Expand Down
252 changes: 157 additions & 95 deletions be/src/vec/functions/function_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,91 +652,180 @@ class FunctionIsIPAddressInRange : public IFunction {
size_t get_number_of_arguments() const override { return 2; }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
if (arguments.size() != 2) {
throw Exception(
ErrorCode::INVALID_ARGUMENT,
"Number of arguments for function {} doesn't match: passed {}, should be 2",
get_name(), arguments.size());
}
const auto& addr_type = arguments[0];
const auto& cidr_type = arguments[1];
if (!is_string(remove_nullable(addr_type)) || !is_string(remove_nullable(cidr_type))) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The arguments of function {} must be String", get_name());
return make_nullable(std::make_shared<DataTypeUInt8>());
}

template <PrimitiveType PT, typename ColumnType>
void execute_impl_with_ip(size_t input_rows_count, bool addr_const, bool cidr_const,
const ColumnString* str_cidr_column, const ColumnPtr addr_column,
ColumnUInt8* col_res) const {
auto& col_res_data = col_res->get_data();
const auto& ip_data = assert_cast<const ColumnType*>(addr_column.get())->get_data();
for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
if constexpr (PT == PrimitiveType::TYPE_IPV4) {
if (cidr._address.as_v4()) {
col_res_data[i] = match_ipv4_subnet(ip_data[addr_idx], cidr._address.as_v4(),
cidr._prefix)
? 1
: 0;
} else {
col_res_data[i] = 0;
}
} else if constexpr (PT == PrimitiveType::TYPE_IPV6) {
if (cidr._address.as_v6()) {
col_res_data[i] = match_ipv6_subnet((uint8*)(&ip_data[addr_idx]),
cidr._address.as_v6(), cidr._prefix)
? 1
: 0;
} else {
col_res_data[i] = 0;
}
}
}
return std::make_shared<DataTypeUInt8>();
}

bool use_default_implementation_for_nulls() const override { return false; }
Status evaluate_inverted_index(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: function 'evaluate_inverted_index' exceeds recommended size/complexity thresholds [readability-function-size]

    Status evaluate_inverted_index(
           ^
Additional context

be/src/vec/functions/function_ip.h:692: 90 lines including whitespace and comments (threshold 80)

    Status evaluate_inverted_index(
           ^

const ColumnsWithTypeAndName& arguments,
const std::vector<vectorized::IndexFieldNameAndTypePair>& data_type_with_names,
std::vector<segment_v2::InvertedIndexIterator*> iterators, uint32_t num_rows,
segment_v2::InvertedIndexResultBitmap& bitmap_result) const override {
DCHECK(arguments.size() == 1);
DCHECK(data_type_with_names.size() == 1);
DCHECK(iterators.size() == 1);
auto* iter = iterators[0];
auto data_type_with_name = data_type_with_names[0];
if (iter == nullptr) {
return Status::OK();
}

if (iter->get_inverted_index_reader_type() != segment_v2::InvertedIndexReaderType::BKD) {
// Not support only bkd index
return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
"Inverted index evaluate skipped, ip range reader can only support by bkd "
"reader");
}
// Get the is_ip_address_in_range from the arguments: cidr
const auto& cidr_column_with_type_and_name = arguments[0];
// in is_ip_address_in_range param is const Field
ColumnPtr arg_column = cidr_column_with_type_and_name.column;
DataTypePtr arg_type = cidr_column_with_type_and_name.type;
if ((is_column_nullable(*arg_column) && !is_column_const(*remove_nullable(arg_column))) ||
(!is_column_nullable(*arg_column) && !is_column_const(*arg_column))) {
// if not we should skip inverted index and evaluate in expression
return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
"Inverted index evaluate skipped, is_ip_address_in_range only support const "
"value");
}
// check param type is string
if (!WhichDataType(*arg_type).is_string()) {
return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
"Inverted index evaluate skipped, is_ip_address_in_range only support string "
"type");
}
// min && max ip address
Field min_ip, max_ip;
IPAddressCIDR cidr = parse_ip_with_cidr(arg_column->get_data_at(0));
if (WhichDataType(remove_nullable(data_type_with_name.second)).is_ipv4() &&
cidr._address.as_v4()) {
auto range = apply_cidr_mask(cidr._address.as_v4(), cidr._prefix);
min_ip = range.first;
max_ip = range.second;
} else if (WhichDataType(remove_nullable(data_type_with_name.second)).is_ipv6() &&
cidr._address.as_v6()) {
auto cidr_range_ipv6_col = ColumnIPv6::create(2, 0);
auto& cidr_range_ipv6_data = cidr_range_ipv6_col->get_data();
apply_cidr_mask(reinterpret_cast<const char*>(cidr._address.as_v6()),
reinterpret_cast<char*>(&cidr_range_ipv6_data[0]),
reinterpret_cast<char*>(&cidr_range_ipv6_data[1]), cidr._prefix);
min_ip = cidr_range_ipv6_data[0];
max_ip = cidr_range_ipv6_data[1];
}
// apply for inverted index
std::shared_ptr<roaring::Roaring> res_roaring = std::make_shared<roaring::Roaring>();
std::shared_ptr<roaring::Roaring> max_roaring = std::make_shared<roaring::Roaring>();
std::shared_ptr<roaring::Roaring> null_bitmap = std::make_shared<roaring::Roaring>();

auto param_type = data_type_with_name.second->get_type_as_type_descriptor().type;
std::unique_ptr<segment_v2::InvertedIndexQueryParamFactory> query_param = nullptr;
// >= min ip
RETURN_IF_ERROR(segment_v2::InvertedIndexQueryParamFactory::create_query_value(
param_type, &min_ip, query_param));
RETURN_IF_ERROR(iter->read_from_inverted_index(
data_type_with_name.first, query_param->get_value(),
segment_v2::InvertedIndexQueryType::GREATER_EQUAL_QUERY, num_rows, res_roaring));
// <= max ip
RETURN_IF_ERROR(segment_v2::InvertedIndexQueryParamFactory::create_query_value(
param_type, &max_ip, query_param));
RETURN_IF_ERROR(iter->read_from_inverted_index(
data_type_with_name.first, query_param->get_value(),
segment_v2::InvertedIndexQueryType::LESS_EQUAL_QUERY, num_rows, max_roaring));

DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
"ip.inverted_index_filtered", "req_id", 0);
LOG(INFO) << "execute inverted index req_id: " << req_id
<< " min: " << res_roaring->cardinality();
});
*res_roaring &= *max_roaring;
DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
"ip.inverted_index_filtered", "req_id", 0);
LOG(INFO) << "execute inverted index req_id: " << req_id
<< " max: " << max_roaring->cardinality()
<< " result: " << res_roaring->cardinality();
});
segment_v2::InvertedIndexResultBitmap result(res_roaring, null_bitmap);
bitmap_result = result;
bitmap_result.mask_out_null();
return Status::OK();
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
"ip.inverted_index_filtered", "req_id", 0);
return Status::Error<ErrorCode::INTERNAL_ERROR>(
"{} has already execute inverted index req_id {} , should not execute expr "
"with rows: {}",
get_name(), req_id, input_rows_count);
});
const auto& addr_column_with_type_and_name = block.get_by_position(arguments[0]);
const auto& cidr_column_with_type_and_name = block.get_by_position(arguments[1]);
WhichDataType addr_type(addr_column_with_type_and_name.type);
WhichDataType cidr_type(cidr_column_with_type_and_name.type);
const auto& [addr_column, addr_const] =
unpack_if_const(addr_column_with_type_and_name.column);
const auto& [cidr_column, cidr_const] =
unpack_if_const(cidr_column_with_type_and_name.column);
const ColumnString* str_addr_column = nullptr;
const ColumnString* str_cidr_column = nullptr;
const NullMap* null_map_addr = nullptr;
const NullMap* null_map_cidr = nullptr;

if (addr_type.is_nullable()) {
const auto* addr_column_nullable =
assert_cast<const ColumnNullable*>(addr_column.get());
str_addr_column =
check_and_get_column<ColumnString>(addr_column_nullable->get_nested_column());
null_map_addr = &addr_column_nullable->get_null_map_data();
} else {
str_addr_column = check_and_get_column<ColumnString>(addr_column.get());
}

if (cidr_type.is_nullable()) {
const auto* cidr_column_nullable =
assert_cast<const ColumnNullable*>(cidr_column.get());
str_cidr_column =
check_and_get_column<ColumnString>(cidr_column_nullable->get_nested_column());
null_map_cidr = &cidr_column_nullable->get_null_map_data();
} else {
str_cidr_column = check_and_get_column<ColumnString>(cidr_column.get());
}

if (!str_addr_column) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Illegal column {} of argument of function {}, expected String",
addr_column->get_name(), get_name());
}

if (!str_cidr_column) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Illegal column {} of argument of function {}, expected String",
cidr_column->get_name(), get_name());
}

auto col_res = ColumnUInt8::create(input_rows_count, 0);
auto& col_res_data = col_res->get_data();

for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);
if (null_map_addr && (*null_map_addr)[addr_idx]) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The arguments of function {} must be String, not NULL",
get_name());
}
if (null_map_cidr && (*null_map_cidr)[cidr_idx]) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The arguments of function {} must be String, not NULL",
get_name());
if (is_ipv4(addr_column_with_type_and_name.type)) {
execute_impl_with_ip<PrimitiveType::TYPE_IPV4, ColumnIPv4>(
input_rows_count, addr_const, cidr_const,
assert_cast<const ColumnString*>(cidr_column.get()), addr_column, col_res);
} else if (is_ipv6(addr_column_with_type_and_name.type)) {
execute_impl_with_ip<PrimitiveType::TYPE_IPV6, ColumnIPv6>(
input_rows_count, addr_const, cidr_const,
assert_cast<const ColumnString*>(cidr_column.get()), addr_column, col_res);
} else {
const auto* str_addr_column = assert_cast<const ColumnString*>(addr_column.get());
const auto* str_cidr_column = assert_cast<const ColumnString*>(cidr_column.get());

for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);

const auto addr =
IPAddressVariant(str_addr_column->get_data_at(addr_idx).to_string_view());
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
col_res_data[i] = is_address_in_range(addr, cidr) ? 1 : 0;
}
const auto addr =
IPAddressVariant(str_addr_column->get_data_at(addr_idx).to_string_view());
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
col_res_data[i] = is_address_in_range(addr, cidr) ? 1 : 0;
}

block.replace_by_position(result, std::move(col_res));
Expand Down Expand Up @@ -839,21 +928,6 @@ class FunctionIPv4CIDRToRange : public IFunction {
std::move(col_upper_range_output)}));
return Status::OK();
}

private:
static inline std::pair<UInt32, UInt32> apply_cidr_mask(UInt32 src, UInt8 bits_to_keep) {
if (bits_to_keep >= 8 * sizeof(UInt32)) {
return {src, src};
}
if (bits_to_keep == 0) {
return {static_cast<UInt32>(0), static_cast<UInt32>(-1)};
}
UInt32 mask = static_cast<UInt32>(-1) << (8 * sizeof(UInt32) - bits_to_keep);
UInt32 lower = src & mask;
UInt32 upper = lower | ~mask;

return {lower, upper};
}
};

class FunctionIPv6CIDRToRange : public IFunction {
Expand Down Expand Up @@ -991,18 +1065,6 @@ class FunctionIPv6CIDRToRange : public IFunction {
return ColumnStruct::create(
Columns {std::move(col_res_lower_range), std::move(col_res_upper_range)});
}

private:
static void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower,
char* __restrict dst_upper, UInt8 bits_to_keep) {
// little-endian mask
const auto& mask = get_cidr_mask_ipv6(bits_to_keep);

for (int8_t i = IPV6_BINARY_LENGTH - 1; i >= 0; --i) {
dst_lower[i] = src[i] & mask[i];
dst_upper[i] = dst_lower[i] | ~mask[i];
}
}
};

class FunctionIsIPv4Compat : public IFunction {
Expand Down
27 changes: 27 additions & 0 deletions be/src/vec/runtime/ip_address_cidr.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,33 @@
#include "vec/common/format_ip.h"
namespace doris {

namespace vectorized {
static inline std::pair<UInt32, UInt32> apply_cidr_mask(UInt32 src, UInt8 bits_to_keep) {
if (bits_to_keep >= 8 * sizeof(UInt32)) {
return {src, src};
}
if (bits_to_keep == 0) {
return {static_cast<UInt32>(0), static_cast<UInt32>(-1)};
}
UInt32 mask = static_cast<UInt32>(-1) << (8 * sizeof(UInt32) - bits_to_keep);
UInt32 lower = src & mask;
UInt32 upper = lower | ~mask;

return {lower, upper};
}

static inline void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower,
char* __restrict dst_upper, UInt8 bits_to_keep) {
// little-endian mask
const auto& mask = get_cidr_mask_ipv6(bits_to_keep);

for (int8_t i = IPV6_BINARY_LENGTH - 1; i >= 0; --i) {
dst_lower[i] = src[i] & mask[i];
dst_upper[i] = dst_lower[i] | ~mask[i];
}
}
} // namespace vectorized

class IPAddressVariant {
public:
explicit IPAddressVariant(std::string_view address_str) {
Expand Down
10 changes: 5 additions & 5 deletions be/test/vec/function/function_ip_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ TEST(FunctionIpTest, FunctionIsIPAddressInRangeTest) {
{
// vector vs vector
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
static_cast<void>(check_function<DataTypeUInt8, false>(func_name, input_types, data_set));
static_cast<void>(check_function<DataTypeUInt8, true>(func_name, input_types, data_set));
}

{
// vector vs scalar
InputTypeSet input_types = {TypeIndex::String, Consted {TypeIndex::String}};
for (const auto& line : data_set) {
DataSet const_cidr_dataset = {line};
static_cast<void>(check_function<DataTypeUInt8, false>(func_name, input_types,
const_cidr_dataset));
static_cast<void>(check_function<DataTypeUInt8, true>(func_name, input_types,
const_cidr_dataset));
}
}

Expand All @@ -75,8 +75,8 @@ TEST(FunctionIpTest, FunctionIsIPAddressInRangeTest) {
InputTypeSet input_types = {Consted {TypeIndex::String}, TypeIndex::String};
for (const auto& line : data_set) {
DataSet const_addr_dataset = {line};
static_cast<void>(check_function<DataTypeUInt8, false>(func_name, input_types,
const_addr_dataset));
static_cast<void>(check_function<DataTypeUInt8, true>(func_name, input_types,
const_addr_dataset));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.IPv4Type;
import org.apache.doris.nereids.types.IPv6Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

Expand All @@ -36,9 +38,13 @@
* scalar function `is_ip_address_in_range`
*/
public class IsIpAddressInRange extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv4Type.INSTANCE, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv4Type.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv6Type.INSTANCE, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv6Type.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(BooleanType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE));

Expand Down
Loading
Loading