Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions be/src/vec/functions/array/function_arrays_overlap.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,30 @@ struct OverlapSetImpl {
using ElementNativeType = typename NativeType<typename T::value_type>::Type;
using Set = phmap::flat_hash_set<ElementNativeType, DefaultHash<ElementNativeType>>;
Set set;
void insert_array(const IColumn* column, size_t start, size_t size) {

// Adds the elements of `column` at positions [start, start + size) to `set`.
//
// `nullable` selects at compile time whether `nullmap` is consulted: when it
// is true, every position whose `nullmap` entry is non-zero is left out of
// the set; when it is false, `nullmap` is never read.
template <bool nullable>
void insert_array(const IColumn* column, const UInt8* nullmap, size_t start, size_t size) {
    const auto& values = assert_cast<const T&>(*column).get_data();
    const size_t end = start + size;
    for (size_t pos = start; pos < end; ++pos) {
        bool is_null = false;
        if constexpr (nullable) {
            is_null = nullmap[pos] != 0;
        }
        if (!is_null) {
            set.insert(values[pos]);
        }
    }
}
bool find_any(const IColumn* column, size_t start, size_t size) {

template <bool nullable>
bool find_any(const IColumn* column, const UInt8* nullmap, size_t start, size_t size) {
const auto& vec = assert_cast<const T&>(*column).get_data();
for (size_t i = start; i < start + size; ++i) {
if constexpr (nullable) {
if (nullmap[i]) {
continue;
}
}

if (set.contains(vec[i])) {
return true;
}
Expand All @@ -84,13 +99,28 @@ template <>
struct OverlapSetImpl<ColumnString> {
using Set = phmap::flat_hash_set<StringRef, DefaultHash<StringRef>>;
Set set;
void insert_array(const IColumn* column, size_t start, size_t size) {

// Adds the values of `column` at positions [start, start + size) to `set`,
// each obtained through `column->get_data_at(position)`.
//
// `nullable` selects at compile time whether `nullmap` is consulted: when it
// is true, every position whose `nullmap` entry is non-zero is left out of
// the set; when it is false, `nullmap` is never read.
template <bool nullable>
void insert_array(const IColumn* column, const UInt8* nullmap, size_t start, size_t size) {
    const size_t end = start + size;
    for (size_t pos = start; pos < end; ++pos) {
        bool is_null = false;
        if constexpr (nullable) {
            is_null = nullmap[pos] != 0;
        }
        if (!is_null) {
            set.insert(column->get_data_at(pos));
        }
    }
}
bool find_any(const IColumn* column, size_t start, size_t size) {

template <bool nullable>
bool find_any(const IColumn* column, const UInt8* nullmap, size_t start, size_t size) {
for (size_t i = start; i < start + size; ++i) {
if constexpr (nullable) {
if (nullmap[i]) {
continue;
}
}

if (set.contains(column->get_data_at(i))) {
return true;
}
Expand Down Expand Up @@ -237,7 +267,6 @@ class FunctionArraysOverlap : public IFunction {
auto dst_null_map = ColumnVector<UInt8>::create(input_rows_count, 0);
UInt8* dst_null_map_data = dst_null_map->get_data().data();

// any array is null or any elements in array is null, return null
RETURN_IF_ERROR(_execute_nullable(left_exec_data, dst_null_map_data));
RETURN_IF_ERROR(_execute_nullable(right_exec_data, dst_null_map_data));

Expand Down Expand Up @@ -334,7 +363,6 @@ class FunctionArraysOverlap : public IFunction {
continue;
}

// any element inside array is NULL, return NULL
if (data.nested_nullmap_data) {
ssize_t start = (*data.offsets_ptr)[row - 1];
ssize_t size = (*data.offsets_ptr)[row] - start;
Expand All @@ -351,14 +379,10 @@ class FunctionArraysOverlap : public IFunction {

template <typename T>
Status _execute_internal(const ColumnArrayExecutionData& left_data,
const ColumnArrayExecutionData& right_data,
const UInt8* dst_nullmap_data, UInt8* dst_data) const {
const ColumnArrayExecutionData& right_data, UInt8* dst_nullmap_data,
UInt8* dst_data) const {
using ExecutorImpl = OverlapSetImpl<T>;
for (ssize_t row = 0; row < left_data.offsets_ptr->size(); ++row) {
if (dst_nullmap_data[row]) {
continue;
}

ssize_t left_start = (*left_data.offsets_ptr)[row - 1];
ssize_t left_size = (*left_data.offsets_ptr)[row] - left_start;
ssize_t right_start = (*right_data.offsets_ptr)[row - 1];
Expand All @@ -368,13 +392,42 @@ class FunctionArraysOverlap : public IFunction {
continue;
}

ExecutorImpl impl;
const auto* small_data = &left_data;
const auto* large_data = &right_data;

ssize_t small_start = left_start;
ssize_t large_start = right_start;
ssize_t small_size = left_size;
ssize_t large_size = right_size;
if (right_size < left_size) {
impl.insert_array(right_data.nested_col, right_start, right_size);
dst_data[row] = impl.find_any(left_data.nested_col, left_start, left_size);
std::swap(small_data, large_data);
std::swap(small_start, large_start);
std::swap(small_size, large_size);
}

ExecutorImpl impl;
if (small_data->nested_nullmap_data) {
impl.template insert_array<true>(small_data->nested_col,
small_data->nested_nullmap_data, small_start,
small_size);
} else {
impl.insert_array(left_data.nested_col, left_start, left_size);
dst_data[row] = impl.find_any(right_data.nested_col, right_start, right_size);
impl.template insert_array<false>(small_data->nested_col,
small_data->nested_nullmap_data, small_start,
small_size);
}

if (large_data->nested_nullmap_data) {
dst_data[row] = impl.template find_any<true>(large_data->nested_col,
large_data->nested_nullmap_data,
large_start, large_size);
} else {
dst_data[row] = impl.template find_any<false>(large_data->nested_col,
large_data->nested_nullmap_data,
large_start, large_size);
}

if (dst_data[row]) {
dst_nullmap_data[row] = 0;
}
}
return Status::OK();
Expand Down
63 changes: 54 additions & 9 deletions be/test/vec/function/function_arrays_overlap_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
// specific language governing permissions and limitations
// under the License.

#include <gtest/gtest.h>

#include <string>

#include "common/status.h"
#include "function_test_util.h"
#include "gtest/gtest_pred_impl.h"
#include "testutil/any_type.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"

namespace doris::vectorized {
Expand Down Expand Up @@ -113,8 +112,15 @@ TEST(function_arrays_overlap_test, arrays_overlap) {
Array vec1 = {ut_type::DECIMALFIELD(17014116.67), ut_type::DECIMALFIELD(-17014116.67),
ut_type::DECIMALFIELD(0.0)};
Array vec2 = {ut_type::DECIMALFIELD(17014116.67)};
DataSet data_set = {
{{vec1, vec2}, UInt8(1)}, {{Null(), vec1}, Null()}, {{empty_arr, vec1}, UInt8(0)}};

Array vec3 = {ut_type::DECIMALFIELD(17014116.67), ut_type::DECIMALFIELD(-17014116.67),
Null()};
Array vec4 = {ut_type::DECIMALFIELD(-17014116.67)};
Array vec5 = {ut_type::DECIMALFIELD(-17014116.68)};
DataSet data_set = {{{vec1, vec2}, UInt8(1)}, {{Null(), vec1}, Null()},
{{vec1, Null()}, Null()}, {{empty_arr, vec1}, UInt8(0)},
{{vec3, vec4}, UInt8(1)}, {{vec3, vec5}, Null()},
{{vec4, vec3}, UInt8(1)}, {{vec5, vec3}, Null()}};

static_cast<void>(check_function<DataTypeUInt8, true>(func_name, input_types, data_set));
}
Expand All @@ -127,10 +133,49 @@ TEST(function_arrays_overlap_test, arrays_overlap) {
Array vec1 = {Field(String("abc", 3)), Field(String("", 0)), Field(String("def", 3))};
Array vec2 = {Field(String("abc", 3))};
Array vec3 = {Field(String("", 0))};
DataSet data_set = {{{vec1, vec2}, UInt8(1)},
{{vec1, vec3}, UInt8(1)},
{{Null(), vec1}, Null()},
{{empty_arr, vec1}, UInt8(0)}};
Array vec4 = {Field(String("abc", 3)), Null()};
Array vec5 = {Field(String("abcd", 4)), Null()};
DataSet data_set = {{{vec1, vec2}, UInt8(1)}, {{vec1, vec3}, UInt8(1)},
{{Null(), vec1}, Null()}, {{empty_arr, vec1}, UInt8(0)},
{{vec4, vec1}, UInt8(1)}, {{vec1, vec5}, Null()},
{{vec1, vec4}, UInt8(1)}, {{vec5, vec1}, Null()}};

static_cast<void>(check_function<DataTypeUInt8, true>(func_name, input_types, data_set));
}

// arrays_overlap(Array<Decimal128V2>, Array<Decimal128V2>), Non-nullable
{
InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Decimal128V2, TypeIndex::Array,
TypeIndex::Decimal128V2};

Array vec1 = {ut_type::DECIMALFIELD(17014116.67), ut_type::DECIMALFIELD(-17014116.67),
ut_type::DECIMALFIELD(0.0)};
Array vec2 = {ut_type::DECIMALFIELD(17014116.67)};

Array vec3 = {ut_type::DECIMALFIELD(17014116.67), ut_type::DECIMALFIELD(-17014116.67)};
Array vec4 = {ut_type::DECIMALFIELD(-17014116.67)};
Array vec5 = {ut_type::DECIMALFIELD(-17014116.68)};
DataSet data_set = {{{vec1, vec2}, UInt8(1)}, {{empty_arr, vec1}, UInt8(0)},
{{vec3, vec4}, UInt8(1)}, {{vec3, vec5}, UInt8(0)},
{{vec4, vec3}, UInt8(1)}, {{vec5, vec3}, UInt8(0)}};

static_cast<void>(check_function<DataTypeUInt8, true>(func_name, input_types, data_set));
}

// arrays_overlap(Array<String>, Array<String>), Non-nullable
{
InputTypeSet input_types = {TypeIndex::Array, TypeIndex::String, TypeIndex::Array,
TypeIndex::String};

Array vec1 = {Field(String("abc", 3)), Field(String("", 0)), Field(String("def", 3))};
Array vec2 = {Field(String("abc", 3))};
Array vec3 = {Field(String("", 0))};
Array vec4 = {Field(String("abc", 3))};
Array vec5 = {Field(String("abcd", 4))};
DataSet data_set = {{{vec1, vec2}, UInt8(1)}, {{vec1, vec3}, UInt8(1)},
{{empty_arr, vec1}, UInt8(0)}, {{vec4, vec1}, UInt8(1)},
{{vec1, vec5}, UInt8(0)}, {{vec1, vec4}, UInt8(1)},
{{vec5, vec1}, UInt8(0)}};

static_cast<void>(check_function<DataTypeUInt8, true>(func_name, input_types, data_set));
}
Expand Down
3 changes: 3 additions & 0 deletions be/test/vec/function/function_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ size_t type_index_to_data_type(const std::vector<AnyType>& input_types, size_t i
return ret;
}
desc.children.push_back(sub_desc.type_desc);
if (sub_desc.is_nullable) {
sub_type = make_nullable(sub_type);
}
type = std::make_shared<DataTypeArray>(sub_type);
return ret + 1;
}
Expand Down
Loading
Loading