Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 158 additions & 73 deletions be/src/vec/exprs/vcompound_pred.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
#include <gen_cpp/Opcodes_types.h>

#include "common/status.h"
#include "gutil/integral_types.h"
#include "util/simd/bits.h"
#include "vec/columns/column.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/data_types/data_type_number.h"
#include "vec/exprs/vectorized_fn_call.h"
#include "vec/exprs/vexpr.h"

Expand Down Expand Up @@ -55,86 +57,166 @@ class VCompoundPred : public VectorizedFnCall {

Status execute(VExprContext* context, vectorized::Block* block,
int* result_column_id) override {
if (children().size() == 1 || !_all_child_is_compound_and_not_const() ||
_children[0]->is_nullable() || _children[1]->is_nullable()) {
// TODO:
// When the child is nullable, make the optimization also take effect, and the processing of this piece may be more complicated
// https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html
if (children().size() == 1 || !_all_child_is_compound_and_not_const()) {
return VectorizedFnCall::execute(context, block, result_column_id);
}

int lhs_id = -1;
int rhs_id = -1;
RETURN_IF_ERROR(_children[0]->execute(context, block, &lhs_id));
ColumnPtr lhs_column =
block->get_by_position(lhs_id).column->convert_to_full_column_if_const();

ColumnPtr lhs_column = block->get_by_position(lhs_id).column;
size_t size = lhs_column->size();
uint8* __restrict data = _get_raw_data(lhs_column);
int filted = simd::count_zero_num((int8_t*)data, size);
bool full = filted == 0;
bool empty = filted == size;
bool lhs_is_nullable = lhs_column->is_nullable();
auto [lhs_data_column, lhs_null_map] =
_get_raw_data_and_null_map(lhs_column, lhs_is_nullable);
int filted = simd::count_zero_num((int8_t*)lhs_data_column, size);
bool lhs_all_true = (filted == 0);
bool lhs_all_false = (filted == size);

bool lhs_all_is_not_null = false;
if (lhs_is_nullable) {
filted = simd::count_zero_num((int8_t*)lhs_null_map, size);
lhs_all_is_not_null = (filted == size);
}

ColumnPtr rhs_column = nullptr;
uint8* __restrict data_rhs = nullptr;
bool full_rhs = false;
bool empty_rhs = false;
uint8* __restrict rhs_data_column = nullptr;
uint8* __restrict rhs_null_map = nullptr;
bool rhs_is_nullable = false;
bool rhs_all_true = false;
bool rhs_all_false = false;
bool rhs_all_is_not_null = false;
bool result_is_nullable = _data_type->is_nullable();

auto get_rhs_colum = [&]() {
if (rhs_id == -1) {
RETURN_IF_ERROR(_children[1]->execute(context, block, &rhs_id));
rhs_column =
block->get_by_position(rhs_id).column->convert_to_full_column_if_const();
data_rhs = _get_raw_data(rhs_column);
int filted = simd::count_zero_num((int8_t*)data_rhs, size);
full_rhs = filted == 0;
empty_rhs = filted == size;
rhs_column = block->get_by_position(rhs_id).column;
rhs_is_nullable = rhs_column->is_nullable();
auto rhs_nullable_column = _get_raw_data_and_null_map(rhs_column, rhs_is_nullable);
rhs_data_column = rhs_nullable_column.first;
rhs_null_map = rhs_nullable_column.second;
int filted = simd::count_zero_num((int8_t*)rhs_data_column, size);
rhs_all_true = (filted == 0);
rhs_all_false = (filted == size);
if (rhs_is_nullable) {
filted = simd::count_zero_num((int8_t*)rhs_null_map, size);
rhs_all_is_not_null = (filted == size);
}
}
return Status::OK();
};

auto return_result_column_id = [&](ColumnPtr res_column, int res_id) -> int {
if (result_is_nullable && !res_column->is_nullable()) {
auto result_column =
ColumnNullable::create(res_column, ColumnUInt8::create(size, 0));
res_id = block->columns();
block->insert({std::move(result_column), _data_type, _expr_name});
}
return res_id;
};

auto create_null_map_column = [&](ColumnPtr null_map_column,
uint8* __restrict null_map_data) {
if (null_map_data == nullptr) {
null_map_column = ColumnUInt8::create(size, 0);
null_map_data = assert_cast<ColumnUInt8*>(null_map_column->assume_mutable().get())
->get_data()
.data();
}
return null_map_data;
};

auto vector_vector_null = [&]<bool is_and_op>() {
auto col_res = ColumnUInt8::create(size);
auto col_nulls = ColumnUInt8::create(size);
auto* __restrict res_datas = assert_cast<ColumnUInt8*>(col_res)->get_data().data();
auto* __restrict res_nulls = assert_cast<ColumnUInt8*>(col_nulls)->get_data().data();
ColumnPtr temp_null_map = nullptr;
// maybe both children are nullable / or one of children is nullable
lhs_null_map = create_null_map_column(temp_null_map, lhs_null_map);
rhs_null_map = create_null_map_column(temp_null_map, rhs_null_map);

if constexpr (is_and_op) {
for (size_t i = 0; i < size; ++i) {
res_nulls[i] = apply_and_null(lhs_data_column[i], lhs_null_map[i],
rhs_data_column[i], rhs_null_map[i]);
res_datas[i] = lhs_data_column[i] & rhs_data_column[i];
}
} else {
for (size_t i = 0; i < size; ++i) {
res_nulls[i] = apply_or_null(lhs_data_column[i], lhs_null_map[i],
rhs_data_column[i], rhs_null_map[i]);
res_datas[i] = lhs_data_column[i] | rhs_data_column[i];
}
}
auto result_column = ColumnNullable::create(std::move(col_res), std::move(col_nulls));
*result_column_id = block->columns();
block->insert({std::move(result_column), _data_type, _expr_name});
};

// false and NULL ----> 0
// true and NULL ----> NULL
if (_op == TExprOpcode::COMPOUND_AND) {
if (empty) {
// empty and any = empty, return lhs
*result_column_id = lhs_id;
//1. not null column: all data is false
//2. nullable column: null map all is not null
if ((lhs_all_false && !lhs_is_nullable) || (lhs_all_false && lhs_all_is_not_null)) {
// false and any = false, return lhs
*result_column_id = return_result_column_id(lhs_column, lhs_id);
} else {
RETURN_IF_ERROR(get_rhs_colum());

if (full) {
// full and any = any, return rhs
*result_column_id = rhs_id;
} else if (empty_rhs) {
// any and empty = empty, return rhs
*result_column_id = rhs_id;
} else if (full_rhs) {
// any and full = any, return lhs
*result_column_id = lhs_id;
if ((lhs_all_true && !lhs_is_nullable) || //not null column
(lhs_all_true && lhs_all_is_not_null)) { //nullable column
// true and any = any, return rhs
*result_column_id = return_result_column_id(rhs_column, rhs_id);
} else if ((rhs_all_false && !rhs_is_nullable) ||
(rhs_all_false && rhs_all_is_not_null)) {
// any and false = false, return rhs
*result_column_id = return_result_column_id(rhs_column, rhs_id);
} else if ((rhs_all_true && !rhs_is_nullable) ||
(rhs_all_true && rhs_all_is_not_null)) {
// any and true = any, return lhs
*result_column_id = return_result_column_id(lhs_column, lhs_id);
} else {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
data[i] &= data_rhs[i];
if (!result_is_nullable) {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
lhs_data_column[i] &= rhs_data_column[i];
}
} else {
vector_vector_null.template operator()<true>();
}
}
}
} else if (_op == TExprOpcode::COMPOUND_OR) {
if (full) {
// full or any = full, return lhs
*result_column_id = lhs_id;
// true or NULL ----> 1
// false or NULL ----> NULL
if ((lhs_all_true && !lhs_is_nullable) || (lhs_all_true && lhs_all_is_not_null)) {
// true or any = true, return lhs
*result_column_id = return_result_column_id(lhs_column, lhs_id);
} else {
RETURN_IF_ERROR(get_rhs_colum());
if (empty) {
// empty or any = any, return rhs
*result_column_id = rhs_id;
} else if (full_rhs) {
// any or full = full, return rhs
*result_column_id = rhs_id;
} else if (empty_rhs) {
// any or empty = any, return lhs
*result_column_id = lhs_id;
if ((lhs_all_false && !lhs_is_nullable) || (lhs_all_false && lhs_all_is_not_null)) {
// false or any = any, return rhs
*result_column_id = return_result_column_id(rhs_column, rhs_id);
} else if ((rhs_all_true && !rhs_is_nullable) ||
(rhs_all_true && rhs_all_is_not_null)) {
// any or true = true, return rhs
*result_column_id = return_result_column_id(rhs_column, rhs_id);
} else if ((rhs_all_false && !rhs_is_nullable) ||
(rhs_all_false && rhs_all_is_not_null)) {
// any or false = any, return lhs
*result_column_id = return_result_column_id(lhs_column, lhs_id);
} else {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
data[i] |= data_rhs[i];
if (!result_is_nullable) {
*result_column_id = lhs_id;
for (size_t i = 0; i < size; i++) {
lhs_data_column[i] |= rhs_data_column[i];
}
} else {
vector_vector_null.template operator()<false>();
}
}
}
Expand All @@ -148,6 +230,15 @@ class VCompoundPred : public VectorizedFnCall {
bool is_compound_predicate() const override { return true; }

private:
static inline constexpr uint8 apply_and_null(UInt8 a, UInt8 l_null, UInt8 b, UInt8 r_null) {
// (<> && false) is false, (true && NULL) is NULL
return (l_null & r_null) | (r_null & (l_null ^ a)) | (l_null & (r_null ^ b));
}
static inline constexpr uint8 apply_or_null(UInt8 a, UInt8 l_null, UInt8 b, UInt8 r_null) {
// (<> || true) is true, (false || NULL) is NULL
return (l_null & r_null) | (r_null & (r_null ^ a)) | (l_null & (l_null ^ b));
}

bool _all_child_is_compound_and_not_const() const {
for (auto child : _children) {
// we can make sure non const compound predicate's return column is allow modifyied locally.
Expand All @@ -158,29 +249,23 @@ class VCompoundPred : public VectorizedFnCall {
return true;
}

uint8* _get_raw_data(ColumnPtr column) const {
if (column->is_nullable()) {
return assert_cast<ColumnUInt8*>(
assert_cast<ColumnNullable*>(column->assume_mutable().get())
->get_nested_column_ptr()
.get())
->get_data()
.data();
} else {
return assert_cast<ColumnUInt8*>(column->assume_mutable().get())->get_data().data();
}
}

uint8* _get_null_map(ColumnPtr column) const {
if (column->is_nullable()) {
return assert_cast<ColumnUInt8*>(
assert_cast<ColumnNullable*>(column->assume_mutable().get())
->get_null_map_column_ptr()
.get())
->get_data()
.data();
std::pair<uint8*, uint8*> _get_raw_data_and_null_map(ColumnPtr column,
bool nullable_column) const {
if (nullable_column) {
auto* nullable_column = assert_cast<ColumnNullable*>(column->assume_mutable().get());
auto* data_column =
assert_cast<ColumnUInt8*>(nullable_column->get_nested_column_ptr().get())
->get_data()
.data();
auto* null_map =
assert_cast<ColumnUInt8*>(nullable_column->get_null_map_column_ptr().get())
->get_data()
.data();
return std::make_pair(data_column, null_map);
} else {
return nullptr;
auto* data_column =
assert_cast<ColumnUInt8*>(column->assume_mutable().get())->get_data().data();
return std::make_pair(data_column, nullptr);
}
}

Expand Down
Loading