Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion be/src/vec/functions/function_java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "util/jni-util.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_map.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/columns/column_vector.h"
Expand Down Expand Up @@ -71,7 +72,8 @@ Status JavaFunctionCall::open(FunctionContext* context, FunctionContext::Functio
jni_env->executor_cl, "convertBasicArguments", "(IZIJJJ)[Ljava/lang/Object;");
jni_env->executor_convert_array_argument_id = env->GetMethodID(
jni_env->executor_cl, "convertArrayArguments", "(IZIJJJJJ)[Ljava/lang/Object;");

jni_env->executor_convert_map_argument_id = env->GetMethodID(
jni_env->executor_cl, "convertMapArguments", "(IZIJJJJJJJJ)[Ljava/lang/Object;");
jni_env->executor_result_basic_batch_id = env->GetMethodID(
jni_env->executor_cl, "copyBatchBasicResult", "(ZI[Ljava/lang/Object;JJJ)V");
jni_env->executor_result_array_batch_id = env->GetMethodID(
Expand Down Expand Up @@ -148,6 +150,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
ColumnPtr null_cols[arg_size];
jclass obj_class = env->FindClass("[Ljava/lang/Object;");
jclass arraylist_class = env->FindClass("Ljava/util/ArrayList;");
// jclass hashmap_class = env->FindClass("Ljava/util/HashMap;");
jobjectArray arg_objects = env->NewObjectArray(arg_size, obj_class, nullptr);
int64_t nullmap_address = 0;
for (size_t arg_idx = 0; arg_idx < arg_size; ++arg_idx) {
Expand Down Expand Up @@ -218,6 +221,54 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
jni_env->executor_convert_array_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address, offset_address, nested_nullmap_address,
nested_data_address, nested_offset_address);
} else if (data_cols[arg_idx]->is_column_map()) {
const ColumnMap* map_col = assert_cast<const ColumnMap*>(data_cols[arg_idx].get());
auto offset_address =
reinterpret_cast<int64_t>(map_col->get_offsets_column().get_raw_data().data);
const ColumnNullable& map_key_column_nullable =
assert_cast<const ColumnNullable&>(map_col->get_keys());
auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr();
auto key_data_column = map_key_column_nullable.get_nested_column_ptr();

auto key_nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(key_data_column_null_map)
->get_data()
.data());
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(key_data_column.get());
key_nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
key_nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}

const ColumnNullable& map_value_column_nullable =
assert_cast<const ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(value_data_column_null_map)
->get_data()
.data());
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
// For the map's value column, pass: the nested null-map address, the nested data address (the chars buffer for strings), and, for strings only, the string offsets address.
if (value_data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(value_data_column.get());
value_nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
value_nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_map_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address, offset_address, key_nested_nullmap_address,
key_nested_data_address, key_nested_offset_address,
value_nested_nullmap_address, value_nested_data_address,
value_nested_offset_address);
} else {
return Status::InvalidArgument(
strings::Substitute("Java UDF doesn't support type $0 now !",
Expand Down
1 change: 1 addition & 0 deletions be/src/vec/functions/function_java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class JavaFunctionCall : public IFunctionBase {
jmethodID executor_evaluate_id;
jmethodID executor_convert_basic_argument_id;
jmethodID executor_convert_array_argument_id;
jmethodID executor_convert_map_argument_id;
jmethodID executor_result_basic_batch_id;
jmethodID executor_result_array_batch_id;
jmethodID executor_close_id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.common.jni.utils;

import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.MapType;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
Expand Down Expand Up @@ -52,7 +53,7 @@
import java.util.Set;

public class UdfUtils {
private static final Logger LOG = Logger.getLogger(UdfUtils.class);
public static final Logger LOG = Logger.getLogger(UdfUtils.class);
public static final Unsafe UNSAFE;
private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
public static final long BYTE_ARRAY_OFFSET;
Expand Down Expand Up @@ -95,15 +96,16 @@ public enum JavaUdfDataType {
DECIMAL32("DECIMAL32", TPrimitiveType.DECIMAL32, 4),
DECIMAL64("DECIMAL64", TPrimitiveType.DECIMAL64, 8),
DECIMAL128("DECIMAL128", TPrimitiveType.DECIMAL128I, 16),
ARRAY_TYPE("ARRAY_TYPE", TPrimitiveType.ARRAY, 0);

ARRAY_TYPE("ARRAY_TYPE", TPrimitiveType.ARRAY, 0),
MAP_TYPE("MAP_TYPE", TPrimitiveType.MAP, 0);
private final String description;
private final TPrimitiveType thriftType;
private final int len;
private int precision;
private int scale;
private Type itemType;

private Type keyType;
private Type valueType;
JavaUdfDataType(String description, TPrimitiveType thriftType, int len) {
this.description = description;
this.thriftType = thriftType;
Expand Down Expand Up @@ -153,6 +155,8 @@ public static Set<JavaUdfDataType> getCandidateTypes(Class<?> c) {
JavaUdfDataType.DECIMAL128);
} else if (c == java.util.ArrayList.class) {
return Sets.newHashSet(JavaUdfDataType.ARRAY_TYPE);
} else if (c == java.util.HashMap.class) {
return Sets.newHashSet(JavaUdfDataType.MAP_TYPE);
}
return Sets.newHashSet(JavaUdfDataType.INVALID_TYPE);
}
Expand Down Expand Up @@ -192,6 +196,22 @@ public Type getItemType() {
/**
 * Stores the element type associated with this data type.
 *
 * <p>NOTE(review): the field belongs to the enum constant itself, so the value
 * appears to be shared by every user of that constant — confirm this is intended.
 *
 * @param type the element type to remember
 */
public void setItemType(final Type type) {
    itemType = type;
}

/**
 * Returns the key type previously stored through {@link #setKeyType(Type)}.
 *
 * @return the stored key type, or {@code null} if none has been set
 */
public Type getKeyType() {
    return this.keyType;
}

/**
 * Returns the value type previously stored through {@link #setValueType(Type)}.
 *
 * @return the stored value type, or {@code null} if none has been set
 */
public Type getValueType() {
    return this.valueType;
}

/**
 * Stores the key type associated with this data type.
 *
 * <p>NOTE(review): the field belongs to the enum constant itself, so the value
 * appears to be shared by every user of that constant — confirm this is intended.
 *
 * @param type the key type to remember
 */
public void setKeyType(final Type type) {
    keyType = type;
}

/**
 * Stores the value type associated with this data type.
 *
 * <p>NOTE(review): the field belongs to the enum constant itself, so the value
 * appears to be shared by every user of that constant — confirm this is intended.
 *
 * @param type the value type to remember
 */
public void setValueType(final Type type) {
    valueType = type;
}
}

public static Pair<Type, Integer> fromThrift(TTypeDesc typeDesc, int nodeIdx) throws InternalException {
Expand Down Expand Up @@ -232,6 +252,14 @@ public static Pair<Type, Integer> fromThrift(TTypeDesc typeDesc, int nodeIdx) th
nodeIdx = childType.second;
break;
}
case MAP: {
Preconditions.checkState(nodeIdx + 1 < typeDesc.getTypesSize());
Pair<Type, Integer> keyType = fromThrift(typeDesc, nodeIdx + 1);
Pair<Type, Integer> valueType = fromThrift(typeDesc, nodeIdx + 1 + keyType.value());
type = new MapType(keyType.key(), valueType.key());
nodeIdx = 1 + keyType.value() + valueType.value();
break;
}

default:
throw new InternalException("Return type " + node.getType() + " is not supported now!");
Expand Down Expand Up @@ -307,6 +335,14 @@ public static Pair<Boolean, JavaUdfDataType> setReturnType(Type retType, Class<?
result.setPrecision(arrType.getItemType().getPrecision());
result.setScale(((ScalarType) arrType.getItemType()).getScalarScale());
}
} else if (retType.isMapType()) {
MapType mapType = (MapType) retType;
result.setKeyType(mapType.getKeyType());
result.setValueType(mapType.getValueType());
if (mapType.getKeyType().isDatetimeV2() || mapType.getKeyType().isDecimalV3()) {
result.setPrecision(mapType.getKeyType().getPrecision());
result.setScale(((ScalarType) mapType.getKeyType()).getScalarScale());
}
}
return Pair.of(res.length != 0, result);
}
Expand All @@ -332,6 +368,10 @@ public static Pair<Boolean, JavaUdfDataType[]> setArgTypes(Type[] parameterTypes
} else if (parameterTypes[finalI].isArrayType()) {
ArrayType arrType = (ArrayType) parameterTypes[finalI];
inputArgTypes[i].setItemType(arrType.getItemType());
} else if (parameterTypes[finalI].isMapType()) {
MapType mapType = (MapType) parameterTypes[finalI];
inputArgTypes[i].setKeyType(mapType.getKeyType());
inputArgTypes[i].setValueType(mapType.getValueType());
}
if (res.length == 0) {
return Pair.of(false, inputArgTypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1202,4 +1202,123 @@ public Object[] convertArrayArg(int argIdx, boolean isNullable, int rowStart, in
}
return argument;
}

/**
 * Builds one converted element list per row in {@code [rowStart, rowEnd)} from
 * raw native buffers, dispatching on {@code type} to the matching
 * {@code UdfConvert.convertArray*Arg} helper.
 *
 * <p>For each row the element range is taken from two consecutive 8-byte entries
 * of the offsets buffer: the entry at {@code row - 1} is the start and the entry
 * at {@code row} is the end. NOTE(review): for row 0 this reads the 8 bytes
 * immediately before {@code offsetsAddr}; presumably the native buffer is laid
 * out so that this slot holds 0 — confirm against the native side.
 *
 * <p>NOTE(review): the name suggests this is invoked separately for a map's key
 * column and value column — confirm against the caller.
 *
 * @param type              element type that selects the conversion helper
 * @param argIdx            index into {@code argTypes}; used for the decimal scale and for error text
 * @param isNullable        forwarded unchanged to the conversion helper
 * @param rowStart          first row to convert (inclusive)
 * @param rowEnd            last row to convert (exclusive)
 * @param nullMapAddr       forwarded unchanged to the conversion helper
 * @param offsetsAddr       base address of the 8-byte-per-row offsets buffer
 * @param nestedNullMapAddr forwarded unchanged to the conversion helper
 * @param dataAddr          forwarded unchanged to the conversion helper
 * @param strOffsetAddr     forwarded only for CHAR / VARCHAR / STRING
 * @return an array (component type {@code ArrayList}) of length {@code rowEnd - rowStart};
 *         slot {@code i} holds the result for row {@code rowStart + i}
 */
public Object[] convertMapArg(PrimitiveType type, int argIdx, boolean isNullable, int rowStart, int rowEnd,
        long nullMapAddr,
        long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
    final int rowCount = rowEnd - rowStart;
    Object[] converted = (Object[]) Array.newInstance(ArrayList.class, rowCount);
    for (int rowIdx = rowStart; rowIdx < rowEnd; ++rowIdx) {
        // Consecutive offsets delimit this row's elements in the nested buffers.
        final long begin = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (rowIdx - 1));
        final long end = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (rowIdx));
        final int elementCount = (int) (end - begin);

        // Stays null only on the unsupported-type path, which mirrors the
        // untouched (null) slot of the freshly created result array.
        Object rowValue = null;
        switch (type) {
            case BOOLEAN:
                rowValue = UdfConvert.convertArrayBooleanArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case TINYINT:
                rowValue = UdfConvert.convertArrayTinyIntArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case SMALLINT:
                rowValue = UdfConvert.convertArraySmallIntArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case INT:
                rowValue = UdfConvert.convertArrayIntArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case BIGINT:
                rowValue = UdfConvert.convertArrayBigIntArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case LARGEINT:
                rowValue = UdfConvert.convertArrayLargeIntArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case FLOAT:
                rowValue = UdfConvert.convertArrayFloatArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DOUBLE:
                rowValue = UdfConvert.convertArrayDoubleArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case CHAR:
            case VARCHAR:
            case STRING:
                // Strings are the only case that also needs the string offsets buffer.
                rowValue = UdfConvert.convertArrayStringArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr, strOffsetAddr);
                break;
            case DATE:
                rowValue = UdfConvert.convertArrayDateArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DATETIME:
                rowValue = UdfConvert.convertArrayDateTimeArg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DATEV2:
                rowValue = UdfConvert.convertArrayDateV2Arg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DATETIMEV2:
                rowValue = UdfConvert.convertArrayDateTimeV2Arg(rowIdx, elementCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DECIMALV2:
            case DECIMAL128:
                // 16 is the per-element byte width passed to the decimal helper.
                rowValue = UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 16L, rowIdx,
                        elementCount, begin, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DECIMAL32:
                rowValue = UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 4L, rowIdx,
                        elementCount, begin, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DECIMAL64:
                rowValue = UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 8L, rowIdx,
                        elementCount, begin, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            default:
                LOG.info("Not support: " + argTypes[argIdx]);
                Preconditions.checkState(false, "Not support type " + argTypes[argIdx].toString());
                break;
        }
        converted[rowIdx - rowStart] = rowValue;
    }
    return converted;
}
}
Loading