Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions be/src/olap/rowset/segment_v2/inverted_index_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ class InvertedIndexQueryParamFactory {
M(PrimitiveType::TYPE_STRING)
M(PrimitiveType::TYPE_DATEV2)
M(PrimitiveType::TYPE_DATETIMEV2)
M(PrimitiveType::TYPE_IPV4)
M(PrimitiveType::TYPE_IPV6)
#undef M
default:
return Status::NotSupported("Unsupported primitive type {} for inverted index reader",
Expand Down
252 changes: 146 additions & 106 deletions be/src/vec/functions/function_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,110 +606,177 @@ class FunctionIsIPAddressInRange : public IFunction {
return std::make_shared<DataTypeUInt8>();
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
const auto& addr_column_with_type_and_name = block.get_by_position(arguments[0]);
const auto& cidr_column_with_type_and_name = block.get_by_position(arguments[1]);
WhichDataType addr_type(addr_column_with_type_and_name.type);
WhichDataType cidr_type(cidr_column_with_type_and_name.type);
const auto& [addr_column, addr_const] =
unpack_if_const(addr_column_with_type_and_name.column);
const auto& [cidr_column, cidr_const] =
unpack_if_const(cidr_column_with_type_and_name.column);
const auto* str_addr_column = assert_cast<const ColumnString*>(addr_column.get());
const auto* str_cidr_column = assert_cast<const ColumnString*>(cidr_column.get());

auto col_res = ColumnUInt8::create(input_rows_count, 0);
template <PrimitiveType PT, typename ColumnType>
void execute_impl_with_ip(size_t input_rows_count, bool addr_const, bool cidr_const,
const ColumnString* str_cidr_column, const ColumnPtr addr_column,
ColumnUInt8* col_res) const {
auto& col_res_data = col_res->get_data();

const auto& ip_data = assert_cast<const ColumnType*>(addr_column.get())->get_data();
for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);

const auto addr =
IPAddressVariant(str_addr_column->get_data_at(addr_idx).to_string_view());
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
col_res_data[i] = is_address_in_range(addr, cidr) ? 1 : 0;
if constexpr (PT == PrimitiveType::TYPE_IPV4) {
if (cidr._address.as_v4()) {
col_res_data[i] = match_ipv4_subnet(ip_data[addr_idx], cidr._address.as_v4(),
cidr._prefix)
? 1
: 0;
} else {
col_res_data[i] = 0;
}
} else if constexpr (PT == PrimitiveType::TYPE_IPV6) {
if (cidr._address.as_v6()) {
col_res_data[i] = match_ipv6_subnet((uint8*)(&ip_data[addr_idx]),
cidr._address.as_v6(), cidr._prefix)
? 1
: 0;
} else {
col_res_data[i] = 0;
}
}
}

block.replace_by_position(result, std::move(col_res));
return Status::OK();
}
};

// Old version: throws an exception when it encounters a NULL value.
class FunctionIsIPAddressInRangeOld : public IFunction {
public:
static constexpr auto name = "is_ip_address_in_range";
static FunctionPtr create() { return std::make_shared<FunctionIsIPAddressInRange>(); }

String get_name() const override { return name; }

size_t get_number_of_arguments() const override { return 2; }
// Evaluates is_ip_address_in_range through the inverted index instead of row by row.
//
// The single argument is the CIDR string. It is parsed into a [min_ip, max_ip] pair, the
// index is queried once for ">= min_ip" and once for "<= max_ip", and the intersection of
// the two row bitmaps is stored in bitmap_result.
//
// Returns Status::OK() without touching bitmap_result when the iterator is null, and
// INVERTED_INDEX_EVALUATE_SKIPPED when the reader is not BKD, the CIDR argument is not a
// constant column, or the CIDR argument is not of string type. Errors from
// create_query_value / read_from_inverted_index are propagated via RETURN_IF_ERROR.
Status evaluate_inverted_index(
const ColumnsWithTypeAndName& arguments,
const std::vector<vectorized::IndexFieldNameAndTypePair>& data_type_with_names,
std::vector<segment_v2::InvertedIndexIterator*> iterators, uint32_t num_rows,
segment_v2::InvertedIndexResultBitmap& bitmap_result) const override {
// Exactly one argument / indexed column / iterator is expected (checked in debug builds only).
DCHECK(arguments.size() == 1);
DCHECK(data_type_with_names.size() == 1);
DCHECK(iterators.size() == 1);
auto* iter = iterators[0];
auto data_type_with_name = data_type_with_names[0];
// No iterator available: nothing to evaluate, bitmap_result is left as passed in.
if (iter == nullptr) {
return Status::OK();
}

// NOTE(review): the next two lines read like the header and body of a separate
// get_return_type_impl definition spliced into this function (diff residue?) — confirm.
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return std::make_shared<DataTypeUInt8>();
if (iter->get_inverted_index_reader_type() != segment_v2::InvertedIndexReaderType::BKD) {
// Only the BKD index reader is supported; for any other reader skip index evaluation.
return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
"Inverted index evaluate skipped, ip range reader can only support by bkd "
"reader");
}
// Fetch the CIDR argument of is_ip_address_in_range.
const auto& cidr_column_with_type_and_name = arguments[0];
// The CIDR parameter must be a constant column (possibly wrapped in a nullable column).
ColumnPtr arg_column = cidr_column_with_type_and_name.column;
DataTypePtr arg_type = cidr_column_with_type_and_name.type;
if ((is_column_nullable(*arg_column) && !is_column_const(*remove_nullable(arg_column))) ||
(!is_column_nullable(*arg_column) && !is_column_const(*arg_column))) {
// Not a constant: skip the inverted index and let the expression be evaluated normally.
return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
"Inverted index evaluate skipped, is_ip_address_in_range only support const "
"value");
}
// The CIDR parameter must be of string type.
if (!WhichDataType(*arg_type).is_string()) {
return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
"Inverted index evaluate skipped, is_ip_address_in_range only support string "
"type");
}
// Derive the minimum and maximum IP address of the CIDR range.
// NOTE(review): when neither branch below matches (e.g. the indexed column type and the
// CIDR address family differ), min_ip and max_ip keep their default-constructed values
// and are still passed to create_query_value — confirm this is intended.
Field min_ip, max_ip;
IPAddressCIDR cidr = parse_ip_with_cidr(arg_column->get_data_at(0));
if (WhichDataType(remove_nullable(data_type_with_name.second)).is_ipv4() &&
cidr._address.as_v4()) {
// IPv4 column with an IPv4 CIDR: the integer overload returns {lower, upper}.
auto range = apply_cidr_mask(cidr._address.as_v4(), cidr._prefix);
min_ip = range.first;
max_ip = range.second;
} else if (WhichDataType(remove_nullable(data_type_with_name.second)).is_ipv6() &&
cidr._address.as_v6()) {
// IPv6 column with an IPv6 CIDR: a two-element ColumnIPv6 serves as scratch storage,
// slot 0 receives the lower bound and slot 1 the upper bound.
auto cidr_range_ipv6_col = ColumnIPv6::create(2, 0);
auto& cidr_range_ipv6_data = cidr_range_ipv6_col->get_data();
apply_cidr_mask(reinterpret_cast<const char*>(cidr._address.as_v6()),
reinterpret_cast<char*>(&cidr_range_ipv6_data[0]),
reinterpret_cast<char*>(&cidr_range_ipv6_data[1]), cidr._prefix);
min_ip = cidr_range_ipv6_data[0];
max_ip = cidr_range_ipv6_data[1];
}
// Query the inverted index. res_roaring receives the ">= min_ip" rows and later the final
// intersection; max_roaring receives the "<= max_ip" rows; null_bitmap is never populated
// in this function and is handed to the result as created.
std::shared_ptr<roaring::Roaring> res_roaring = std::make_shared<roaring::Roaring>();
std::shared_ptr<roaring::Roaring> max_roaring = std::make_shared<roaring::Roaring>();
std::shared_ptr<roaring::Roaring> null_bitmap = std::make_shared<roaring::Roaring>();

auto param_type = data_type_with_name.second->get_type_as_type_descriptor().type;
std::unique_ptr<segment_v2::InvertedIndexQueryParamFactory> query_param = nullptr;
// rows with value >= min ip
RETURN_IF_ERROR(segment_v2::InvertedIndexQueryParamFactory::create_query_value(
param_type, &min_ip, query_param));
RETURN_IF_ERROR(iter->read_from_inverted_index(
data_type_with_name.first, query_param->get_value(),
segment_v2::InvertedIndexQueryType::GREATER_EQUAL_QUERY, num_rows, res_roaring));
// rows with value <= max ip
RETURN_IF_ERROR(segment_v2::InvertedIndexQueryParamFactory::create_query_value(
param_type, &max_ip, query_param));
RETURN_IF_ERROR(iter->read_from_inverted_index(
data_type_with_name.first, query_param->get_value(),
segment_v2::InvertedIndexQueryType::LESS_EQUAL_QUERY, num_rows, max_roaring));

// Debug point: log the cardinality of the ">= min" result before intersecting.
DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
"ip.inverted_index_filtered", "req_id", 0);
LOG(INFO) << "execute inverted index req_id: " << req_id
<< " min: " << res_roaring->cardinality();
});
// Keep only rows that satisfy both bounds.
*res_roaring &= *max_roaring;
// Debug point: log the "<= max" cardinality and the cardinality of the intersection.
DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
"ip.inverted_index_filtered", "req_id", 0);
LOG(INFO) << "execute inverted index req_id: " << req_id
<< " max: " << max_roaring->cardinality()
<< " result: " << res_roaring->cardinality();
});
// Publish the result to the caller, then apply its null mask.
segment_v2::InvertedIndexResultBitmap result(res_roaring, null_bitmap);
bitmap_result = result;
bitmap_result.mask_out_null();
return Status::OK();
}

bool use_default_implementation_for_nulls() const override { return false; }

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
"ip.inverted_index_filtered", "req_id", 0);
return Status::Error<ErrorCode::INTERNAL_ERROR>(
"{} has already execute inverted index req_id {} , should not execute expr "
"with rows: {}",
get_name(), req_id, input_rows_count);
});
const auto& addr_column_with_type_and_name = block.get_by_position(arguments[0]);
const auto& cidr_column_with_type_and_name = block.get_by_position(arguments[1]);
WhichDataType addr_type(addr_column_with_type_and_name.type);
WhichDataType cidr_type(cidr_column_with_type_and_name.type);
const auto& [addr_column, addr_const] =
unpack_if_const(addr_column_with_type_and_name.column);
const auto& [cidr_column, cidr_const] =
unpack_if_const(cidr_column_with_type_and_name.column);
const ColumnString* str_addr_column = nullptr;
const ColumnString* str_cidr_column = nullptr;
const NullMap* null_map_addr = nullptr;
const NullMap* null_map_cidr = nullptr;

if (addr_type.is_nullable()) {
const auto* addr_column_nullable =
assert_cast<const ColumnNullable*>(addr_column.get());
str_addr_column = assert_cast<const ColumnString*>(
addr_column_nullable->get_nested_column_ptr().get());
null_map_addr = &addr_column_nullable->get_null_map_data();
} else {
str_addr_column = assert_cast<const ColumnString*>(addr_column.get());
}

if (cidr_type.is_nullable()) {
const auto* cidr_column_nullable =
assert_cast<const ColumnNullable*>(cidr_column.get());
str_cidr_column = assert_cast<const ColumnString*>(
cidr_column_nullable->get_nested_column_ptr().get());
null_map_cidr = &cidr_column_nullable->get_null_map_data();
} else {
str_cidr_column = assert_cast<const ColumnString*>(cidr_column.get());
}

auto col_res = ColumnUInt8::create(input_rows_count, 0);
auto& col_res_data = col_res->get_data();

for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);
if (null_map_addr && (*null_map_addr)[addr_idx]) [[unlikely]] {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The arguments of function {} must be String, not NULL",
get_name());
}
if (null_map_cidr && (*null_map_cidr)[cidr_idx]) [[unlikely]] {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The arguments of function {} must be String, not NULL",
get_name());
if (is_ipv4(addr_column_with_type_and_name.type)) {
execute_impl_with_ip<PrimitiveType::TYPE_IPV4, ColumnIPv4>(
input_rows_count, addr_const, cidr_const,
assert_cast<const ColumnString*>(cidr_column.get()), addr_column, col_res);
} else if (is_ipv6(addr_column_with_type_and_name.type)) {
execute_impl_with_ip<PrimitiveType::TYPE_IPV6, ColumnIPv6>(
input_rows_count, addr_const, cidr_const,
assert_cast<const ColumnString*>(cidr_column.get()), addr_column, col_res);
} else {
const auto* str_addr_column = assert_cast<const ColumnString*>(addr_column.get());
const auto* str_cidr_column = assert_cast<const ColumnString*>(cidr_column.get());

for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);

const auto addr =
IPAddressVariant(str_addr_column->get_data_at(addr_idx).to_string_view());
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
col_res_data[i] = is_address_in_range(addr, cidr) ? 1 : 0;
}
const auto addr =
IPAddressVariant(str_addr_column->get_data_at(addr_idx).to_string_view());
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
col_res_data[i] = is_address_in_range(addr, cidr) ? 1 : 0;
}

block.replace_by_position(result, std::move(col_res));
Expand Down Expand Up @@ -797,21 +864,6 @@ class FunctionIPv4CIDRToRange : public IFunction {
std::move(col_upper_range_output)}));
return Status::OK();
}

private:
/// Returns the inclusive {lower, upper} bounds obtained by keeping the leading
/// `bits_to_keep` bits of `src` and clearing (lower) or setting (upper) the rest.
///
/// @param src          32-bit value the mask is applied to.
/// @param bits_to_keep Number of leading bits of `src` that are preserved.
/// @return Pair of {lower bound, upper bound}.
static inline std::pair<UInt32, UInt32> apply_cidr_mask(UInt32 src, UInt8 bits_to_keep) {
    constexpr auto width_in_bits = 8 * sizeof(UInt32);
    constexpr auto every_bit_set = static_cast<UInt32>(-1);

    // Prefix covers the whole value: both bounds are the value itself.
    if (bits_to_keep >= width_in_bits) {
        return {src, src};
    }
    // Empty prefix is handled separately so the shift below never uses the full width.
    if (bits_to_keep == 0) {
        return {static_cast<UInt32>(0), every_bit_set};
    }

    const UInt32 prefix_mask = every_bit_set << (width_in_bits - bits_to_keep);
    const UInt32 range_begin = src & prefix_mask;
    const UInt32 range_end = range_begin | ~prefix_mask;
    return {range_begin, range_end};
}
};

class FunctionIPv6CIDRToRange : public IFunction {
Expand Down Expand Up @@ -936,18 +988,6 @@ class FunctionIPv6CIDRToRange : public IFunction {
return ColumnStruct::create(
Columns {std::move(col_res_lower_range), std::move(col_res_upper_range)});
}

private:
/// Applies the CIDR mask selected by `bits_to_keep` to `src`, one byte at a time.
///
/// @param src          Input bytes; IPV6_BINARY_LENGTH bytes are read.
/// @param dst_lower    Output: each byte of `src` ANDed with the mask byte.
/// @param dst_upper    Output: each lower-bound byte ORed with the inverted mask byte.
/// @param bits_to_keep Prefix length handed to get_cidr_mask_ipv6 to pick the mask.
static void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower,
                            char* __restrict dst_upper, UInt8 bits_to_keep) {
    // little-endian mask
    const auto& cidr_mask = get_cidr_mask_ipv6(bits_to_keep);

    // Walk the bytes from the last index down to index 0.
    for (int pos = IPV6_BINARY_LENGTH; pos-- > 0;) {
        const auto mask_byte = cidr_mask[pos];
        dst_lower[pos] = src[pos] & mask_byte;
        dst_upper[pos] = dst_lower[pos] | ~mask_byte;
    }
}
};

class FunctionIsIPv4Compat : public IFunction {
Expand Down
27 changes: 27 additions & 0 deletions be/src/vec/runtime/ip_address_cidr.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,33 @@
#include "vec/common/format_ip.h"
namespace doris {

namespace vectorized {
/// Computes the inclusive address range covered by a CIDR block on a 32-bit address.
///
/// @param src          Address as a 32-bit unsigned integer.
/// @param bits_to_keep Prefix length: the number of leading bits of `src` kept as-is.
/// @return {lower, upper}, where lower has the remaining bits cleared and upper has
///         them set. A prefix of 32 or more yields {src, src}; a prefix of 0 yields
///         {0, all ones}.
static inline std::pair<UInt32, UInt32> apply_cidr_mask(UInt32 src, UInt8 bits_to_keep) {
    constexpr auto total_bits = 8 * sizeof(UInt32);
    constexpr auto all_ones = static_cast<UInt32>(-1);

    // A full-length (or longer) prefix pins every bit of the address.
    if (bits_to_keep >= total_bits) {
        return {src, src};
    }
    // Handled explicitly: a zero prefix would otherwise shift by the full bit width.
    if (bits_to_keep == 0) {
        return {static_cast<UInt32>(0), all_ones};
    }

    const UInt32 network_mask = all_ones << (total_bits - bits_to_keep);
    const UInt32 first_addr = src & network_mask;
    const UInt32 last_addr = first_addr | ~network_mask;
    return {first_addr, last_addr};
}

/// Computes the lower and upper bounds of a CIDR range over a byte-array address.
///
/// @param src          Input address bytes; IPV6_BINARY_LENGTH bytes are read.
/// @param dst_lower    Output buffer for the lower bound (src AND mask, per byte).
/// @param dst_upper    Output buffer for the upper bound (lower OR inverted mask, per byte).
/// @param bits_to_keep Prefix length used to look up the mask via get_cidr_mask_ipv6.
static inline void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower,
                                   char* __restrict dst_upper, UInt8 bits_to_keep) {
    // little-endian mask
    const auto& prefix_mask = get_cidr_mask_ipv6(bits_to_keep);

    // Process bytes from the highest index down to index 0.
    for (int idx = IPV6_BINARY_LENGTH; idx-- > 0;) {
        const auto keep_bits = prefix_mask[idx];
        dst_lower[idx] = src[idx] & keep_bits;
        dst_upper[idx] = dst_lower[idx] | ~keep_bits;
    }
}
} // namespace vectorized

class IPAddressVariant {
public:
explicit IPAddressVariant(std::string_view address_str) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.IPv4Type;
import org.apache.doris.nereids.types.IPv6Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

Expand All @@ -39,6 +41,10 @@ public class IsIpAddressInRange extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

/**
 * Accepted signatures; every one returns BOOLEAN.
 * The first argument is either an IPv4 or IPv6 value paired with a VARCHAR or STRING
 * second argument, or both arguments are VARCHAR, or both are STRING.
 */
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv4Type.INSTANCE, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv4Type.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv6Type.INSTANCE, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv6Type.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(BooleanType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE));

Expand Down
Loading