Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions be/src/vec/functions/array/function_array_filter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <fmt/format.h>
#include <glog/logging.h>
#include <stddef.h>

#include <memory>
#include <ostream>
#include <string>
#include <utility>

#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_vector.h"
#include "vec/columns/columns_number.h"
#include "vec/core/block.h"
#include "vec/core/column_numbers.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/functions/array/function_array_utils.h"
#include "vec/functions/function.h"
#include "vec/functions/simple_function_factory.h"

namespace doris {
class FunctionContext;
} // namespace doris

namespace doris::vectorized {

class FunctionArrayFilter : public IFunction {
public:
static constexpr auto name = "array_filter";
static FunctionPtr create() { return std::make_shared<FunctionArrayFilter>(); }

/// Get function name.
String get_name() const override { return name; }

bool is_variadic() const override { return false; }

size_t get_number_of_arguments() const override { return 2; }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
DCHECK(is_array(arguments[0]))
<< "First argument for function: " << name
<< " should be DataTypeArray but it has type " << arguments[0]->get_name() << ".";
return arguments[0];
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) override {
//TODO: maybe need optimize not convert
auto first_column =
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
auto second_column =
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();

const ColumnArray& first_col_array = assert_cast<const ColumnArray&>(*first_column);
const auto& first_off_data =
assert_cast<const ColumnArray::ColumnOffsets&>(first_col_array.get_offsets_column())
.get_data();
const auto& first_nested_nullable_column =
assert_cast<const ColumnNullable&>(*first_col_array.get_data_ptr());

const ColumnArray& second_col_array = assert_cast<const ColumnArray&>(*second_column);
const auto& second_off_data = assert_cast<const ColumnArray::ColumnOffsets&>(
second_col_array.get_offsets_column())
.get_data();
const auto& second_nested_null_map_data =
assert_cast<const ColumnNullable&>(*second_col_array.get_data_ptr())
.get_null_map_column()
.get_data();
const auto& second_nested_column =
assert_cast<const ColumnNullable&>(*second_col_array.get_data_ptr())
.get_nested_column();
const auto& second_nested_data =
assert_cast<const ColumnUInt8&>(second_nested_column).get_data();

auto result_data_column = first_nested_nullable_column.clone_empty();
auto result_offset_column = ColumnArray::ColumnOffsets::create();
auto& result_offset_data = result_offset_column->get_data();
vectorized::IColumn::Selector selector;
selector.reserve(first_off_data.size());
result_offset_data.reserve(input_rows_count);

for (size_t row = 0; row < input_rows_count; ++row) {
unsigned long count = 0;
auto first_offset_start = first_off_data[row - 1];
auto first_offset_end = first_off_data[row];
auto second_offset_start = second_off_data[row - 1];
auto second_offset_end = second_off_data[row];
auto move_off = second_offset_start;
for (auto off = first_offset_start;
off < first_offset_end && move_off < second_offset_end; // not out range
++off) {
if (second_nested_null_map_data[move_off] == 0 && // not null
second_nested_data[move_off] == 1) { // not 0
count++;
selector.push_back(off);
}
move_off++;
}
result_offset_data.push_back(count + result_offset_data.back());
}
first_nested_nullable_column.append_data_by_selector(result_data_column, selector);

auto res_column =
ColumnArray::create(std::move(result_data_column), std::move(result_offset_column));
block.replace_by_position(result, std::move(res_column));
return Status::OK();
}
};

void register_function_array_filter_function(SimpleFunctionFactory& factory) {
factory.register_function<FunctionArrayFilter>();
}

} // namespace doris::vectorized
2 changes: 2 additions & 0 deletions be/src/vec/functions/array/function_array_register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void register_function_array_pushback(SimpleFunctionFactory& factory);
void register_function_array_first_or_last_index(SimpleFunctionFactory& factory);
void register_function_array_cum_sum(SimpleFunctionFactory& factory);
void register_function_array_count(SimpleFunctionFactory&);
void register_function_array_filter_function(SimpleFunctionFactory&);

void register_function_array(SimpleFunctionFactory& factory) {
register_function_array_shuffle(factory);
Expand Down Expand Up @@ -88,6 +89,7 @@ void register_function_array(SimpleFunctionFactory& factory) {
register_function_array_first_or_last_index(factory);
register_function_array_cum_sum(factory);
register_function_array_count(factory);
register_function_array_filter_function(factory);
}

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ array_filter(lambda,array)

</version>

<version since="2.0.2">

array array_filter(array arr, array_bool filter_column)

</version>

### description

#### Syntax
```sql
ARRAY<T> array_filter(lambda, ARRAY<T> arr1, ARRAY<T> arr2, ... )
ARRAY<T> array_filter(ARRAY<T> arr)
ARRAY<T> array_filter(lambda, ARRAY<T> arr)
ARRAY<T> array_filter(ARRAY<T> arr, ARRAY<Bool> filter_column)
```

Use the lambda expression as the input parameter to calculate and filter the data of the ARRAY column of the other input parameter.
Expand All @@ -47,11 +53,21 @@ And filter out the values of 0 and NULL in the result.
array_filter(x->x>0, array1);
array_filter(x->(x+2)=10, array1);
array_filter(x->(abs(x)-2)>0, array1);
array_filter(c_array,[0,1,0]);
```

### example

```shell
mysql [test]>select c_array,array_filter(c_array,[0,1,0]) from array_test;
+-----------------+----------------------------------------------------+
| c_array | array_filter(`c_array`, ARRAY(FALSE, TRUE, FALSE)) |
+-----------------+----------------------------------------------------+
| [1, 2, 3, 4, 5] | [2] |
| [6, 7, 8] | [7] |
| [] | [] |
| NULL | NULL |
+-----------------+----------------------------------------------------+

mysql [test]>select array_filter(x->(x > 1),[1,2,3,0,null]);
+----------------------------------------------------------------------------------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ array_filter(lambda,array)

</version>

<version since="2.0.2">

array array_filter(array arr, array_bool filter_column)

</version>

### description

#### Syntax
```sql
ARRAY<T> array_filter(lambda, ARRAY<T> arr1, ARRAY<T> arr2, ... )
ARRAY<T> array_filter(ARRAY<T> arr)
ARRAY<T> array_filter(lambda, ARRAY<T> arr)
ARRAY<T> array_filter(ARRAY<T> arr, ARRAY<Bool> filter_column)
```

使用lambda表达式作为输入参数,计算筛选另外的输入参数ARRAY列的数据。
Expand All @@ -47,12 +53,22 @@ ARRAY<T> array_filter(ARRAY<T> arr)
array_filter(x->x>0, array1);
array_filter(x->(x+2)=10, array1);
array_filter(x->(abs(x)-2)>0, array1);

array_filter(c_array,[0,1,0]);
```

### example

```shell
mysql [test]>select c_array,array_filter(c_array,[0,1,0]) from array_test;
+-----------------+----------------------------------------------------+
| c_array | array_filter(`c_array`, ARRAY(FALSE, TRUE, FALSE)) |
+-----------------+----------------------------------------------------+
| [1, 2, 3, 4, 5] | [2] |
| [6, 7, 8] | [7] |
| [] | [] |
| NULL | NULL |
+-----------------+----------------------------------------------------+

mysql [test]>select array_filter(x->(x > 1),[1,2,3,0,null]);
+----------------------------------------------------------------------------------------------+
| array_filter(ARRAY(1, 2, 3, 0, NULL), array_map([x] -> (x(0) > 1), ARRAY(1, 2, 3, 0, NULL))) |
Expand Down
5 changes: 5 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,11 @@ private boolean findSlotRefByName(String colName) {
if (slot.getColumnName() != null && slot.getColumnName().equals(colName)) {
return true;
}
} else if (this instanceof ColumnRefExpr) {
ColumnRefExpr slot = (ColumnRefExpr) this;
if (slot.getName() != null && slot.getName().equals(colName)) {
return true;
}
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,37 @@ protected void toThrift(TExprNode msg) {
msg.node_type = TExprNodeType.LAMBDA_FUNCTION_CALL_EXPR;
}
}

@Override
public String toSqlImpl() {
StringBuilder sb = new StringBuilder();
sb.append(getFnName().getFunction());
sb.append("(");
int childSize = children.size();
Expr lastExpr = getChild(childSize - 1);
// eg: select array_map(x->x>10, k1) from table,
// but we need analyze each param, so change the function like this in parser
// array_map(x->x>10, k1) ---> array_map(k1, x->x>10),
// so maybe the lambda expr is the end position. and need this check.
boolean lastIsLambdaExpr = (lastExpr instanceof LambdaFunctionExpr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment for the code, trickly here

if (lastIsLambdaExpr) {
sb.append(lastExpr.toSql());
sb.append(", ");
}
for (int i = 0; i < childSize - 1; ++i) {
sb.append(getChild(i).toSql());
if (i != childSize - 2) {
sb.append(", ");
}
}
// and some functions is only implement as a normal array function;
// but also want use as lambda function, select array_sortby(x->x,['b','a','c']);
// so we convert to: array_sortby(array('b', 'a', 'c'), array_map(x -> `x`, array('b', 'a', 'c')))
if (lastIsLambdaExpr == false) {
sb.append(", ");
sb.append(lastExpr.toSql());
}
sb.append(")");
return sb.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,20 @@ protected void analyzeImpl(Analyzer analyzer) throws AnalysisException {

@Override
protected String toSqlImpl() {
return String.format("%s -> %s", names.toString(), getChild(0).toSql());
String nameStr = "";
Expr lambdaExpr = slotExpr.get(0);
int exprSize = names.size();
for (int i = 0; i < exprSize; ++i) {
nameStr = nameStr + names.get(i);
if (i != exprSize - 1) {
nameStr = nameStr + ",";
}
}
if (exprSize > 1) {
nameStr = "(" + nameStr + ")";
}
String res = String.format("%s -> %s", nameStr, lambdaExpr.toSql());
return res;
}

@Override
Expand Down
26 changes: 26 additions & 0 deletions regression-test/data/ddl_p0/test_create_view.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test_view_1 --
1 [1, 2, 3]
2 [10, -2, 8]
3 [-1, 20, 0]

-- !test_view_2 --
1 [1, 2, 3] [1, 1, 1]
2 [10, -2, 8] [1, 0, 1]
3 [-1, 20, 0] [0, 1, 0]

-- !test_view_3 --
1 [1, 2, 3] [1, 2, 3] [1, 2, 3]
2 [10, -2, 8] [10, 8] [10, 8]
3 [-1, 20, 0] [20] [20]

-- !test_view_4 --
1 [1, 2, 3] [1, 2, 3] [1, 2, 3]
2 [10, -2, 8] [10, 8] [10, 8]
3 [-1, 20, 0] [20] [20]

-- !test_view_5 --
1 [1, 2, 3] [1, 1, 1]
2 [10, -2, 8] [1, 0, 1]
3 [-1, 20, 0] [0, 1, 0]

42 changes: 42 additions & 0 deletions regression-test/suites/ddl_p0/test_create_view.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,46 @@ suite("test_create_view") {
sql """DROP VIEW IF EXISTS my_view"""
sql """DROP TABLE IF EXISTS t1"""
sql """DROP TABLE IF EXISTS t2"""


sql """DROP TABLE IF EXISTS view_baseall"""
sql """DROP VIEW IF EXISTS test_view7"""
sql """DROP VIEW IF EXISTS test_view8"""
sql """
CREATE TABLE `view_baseall` (
`k1` int(11) NULL,
`k3` array<int> NULL
) ENGINE=OLAP
DUPLICATE KEY(`k1`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`k1`) BUCKETS 5
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"is_being_synced" = "false",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false",
"enable_single_replica_compaction" = "false"
);
"""
sql """insert into view_baseall values(1,[1,2,3]);"""
sql """insert into view_baseall values(2,[10,-2,8]);"""
sql """insert into view_baseall values(3,[-1,20,0]);"""

qt_test_view_1 """ select * from view_baseall order by k1; """
qt_test_view_2 """ select *, array_map(x->x>0,k3) from view_baseall order by k1; """
qt_test_view_3 """ select *, array_filter(x->x>0,k3),array_filter(`k3`, array_map(x -> x > 0, `k3`)) from view_baseall order by k1; """


sql """
create view IF NOT EXISTS test_view7 (k1,k2,k3,k4) as
select *, array_filter(x->x>0,k3),array_filter(`k3`, array_map(x -> x > 0, `k3`)) from view_baseall order by k1;
"""
qt_test_view_4 """ select * from test_view7 order by k1; """

sql """
create view IF NOT EXISTS test_view8 (k1,k2,k3) as
select *, array_map(x->x>0,k3) from view_baseall order by k1;
"""
qt_test_view_5 """ select * from test_view8 order by k1; """
}