From a7c806842f57cd411e3c92493dc1bb05de113d56 Mon Sep 17 00:00:00 2001 From: Mryange <59914473+Mryange@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:03:44 +0800 Subject: [PATCH 1/2] [vectorized](udaf) java udaf support with map type (#22397) [vectorized](udaf) java udaf support with map type (#22397) * test * remove some unused * update * add case --- .../aggregate_function_java_udaf.h | 313 ++++++----- be/src/vec/functions/function_java_udf.cpp | 18 - be/src/vec/functions/function_java_udf.h | 33 +- .../org/apache/doris/udf/BaseExecutor.java | 500 ++++++++++++++++++ .../org/apache/doris/udf/UdafExecutor.java | 21 +- .../org/apache/doris/udf/UdfExecutor.java | 494 +---------------- .../data/javaudf_p0/test_javaudf_agg_map.out | 7 + .../org/apache/doris/udf/MySumMapInt.java | 64 +++ .../org/apache/doris/udf/MySumMapIntDou.java | 64 +++ .../javaudf_p0/test_javaudf_agg_map.groovy | 78 +++ .../javaudf_p0/test_javaudf_ret_map.groovy | 6 + 11 files changed, 915 insertions(+), 683 deletions(-) create mode 100644 regression-test/data/javaudf_p0/test_javaudf_agg_map.out create mode 100644 regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapInt.java create mode 100644 regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapIntDou.java create mode 100644 regression-test/suites/javaudf_p0/test_javaudf_agg_map.groovy diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h index 6fe47420649c9c..d51c219f3f108f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -31,6 +31,7 @@ #include "util/jni-util.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column_array.h" +#include "vec/columns/column_map.h" #include "vec/columns/column_string.h" #include "vec/common/string_ref.h" #include "vec/core/field.h" @@ -56,12 +57,6 @@ struct AggregateJavaUdafData { AggregateJavaUdafData() = default; AggregateJavaUdafData(int64_t num_args) { argument_size = num_args; - input_values_buffer_ptr = std::make_unique(num_args); - input_nulls_buffer_ptr = std::make_unique(num_args); - input_offsets_ptrs = std::make_unique(num_args); - input_array_nulls_buffer_ptr = std::make_unique(num_args); - input_array_string_offsets_ptrs = std::make_unique(num_args); - input_place_ptrs = std::make_unique(0); output_value_buffer = std::make_unique(0); output_null_value = std::make_unique(0); output_offsets_ptr = std::make_unique(0); @@ -93,16 +88,8 @@ struct AggregateJavaUdafData { TJavaUdfExecutorCtorParams ctor_params; ctor_params.__set_fn(fn); ctor_params.__set_location(local_location); - ctor_params.__set_input_offsets_ptrs((int64_t)input_offsets_ptrs.get()); - ctor_params.__set_input_buffer_ptrs((int64_t)input_values_buffer_ptr.get()); - ctor_params.__set_input_nulls_ptrs((int64_t)input_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_nulls_buffer_ptr( - (int64_t)input_array_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_string_offsets_ptrs( - (int64_t)input_array_string_offsets_ptrs.get()); ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get()); - ctor_params.__set_input_places_ptr((int64_t)input_place_ptrs.get()); ctor_params.__set_output_null_ptr((int64_t)output_null_value.get()); ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get()); @@ -188,6 +175,57 @@ struct AggregateJavaUdafData { arg_column_nullable, row_num_start, row_num_end, nullmap_address, offset_address, nested_nullmap_address, nested_data_address, nested_offset_address); + } else if (data_col->is_column_map()) { + const ColumnMap* map_col = assert_cast(data_col); + auto offset_address = reinterpret_cast( + map_col->get_offsets_column().get_raw_data().data); + const ColumnNullable& map_key_column_nullable = + assert_cast(map_col->get_keys()); + auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr(); + auto key_data_column = map_key_column_nullable.get_nested_column_ptr(); + + auto key_nested_nullmap_address = reinterpret_cast( + check_and_get_column>(key_data_column_null_map) + ->get_data() + .data()); + int64_t key_nested_data_address = 0, key_nested_offset_address = 0; + if (key_data_column->is_column_string()) { + const ColumnString* col = + assert_cast(key_data_column.get()); + key_nested_data_address = reinterpret_cast(col->get_chars().data()); + key_nested_offset_address = + reinterpret_cast(col->get_offsets().data()); + } else { + key_nested_data_address = + reinterpret_cast(key_data_column->get_raw_data().data); + } + + const ColumnNullable& map_value_column_nullable = + assert_cast(map_col->get_values()); + auto value_data_column_null_map = + map_value_column_nullable.get_null_map_column_ptr(); + auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); + auto value_nested_nullmap_address = reinterpret_cast( + check_and_get_column>(value_data_column_null_map) + ->get_data() + .data()); + int64_t value_nested_data_address = 0, value_nested_offset_address = 0; + if (value_data_column->is_column_string()) { + const ColumnString* col = + assert_cast(value_data_column.get()); + value_nested_data_address = reinterpret_cast(col->get_chars().data()); + value_nested_offset_address = + reinterpret_cast(col->get_offsets().data()); + } else { + value_nested_data_address = + reinterpret_cast(value_data_column->get_raw_data().data); + } + arr_obj = (jobjectArray)env->CallObjectMethod( + executor_obj, executor_convert_map_argument_id, arg_idx, + arg_column_nullable, row_num_start, row_num_end, nullmap_address, + offset_address, key_nested_nullmap_address, key_nested_data_address, + key_nested_offset_address, value_nested_nullmap_address, + value_nested_data_address, value_nested_offset_address); } else { return Status::InvalidArgument( strings::Substitute("Java UDAF doesn't support type is $0 now !", @@ -262,131 +300,133 @@ struct AggregateJavaUdafData { *output_null_value = reinterpret_cast(nullable.get_null_map_column().get_raw_data().data); auto& data_col = nullable.get_nested_column(); - -#ifndef EVALUATE_JAVA_UDAF -#define EVALUATE_JAVA_UDAF \ - if (data_col.is_column_string()) { \ - const ColumnString* str_col = check_and_get_column(data_col); \ - ColumnString::Chars& chars = const_cast(str_col->get_chars()); \ - ColumnString::Offsets& offsets = \ - const_cast(str_col->get_offsets()); \ - int increase_buffer_size = 0; \ - int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - chars.resize(buffer_size); \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_offsets_ptr = reinterpret_cast(offsets.data()); \ - *output_intermediate_state_ptr = chars.size(); \ - jboolean res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \ - executor_result_id, to.size() - 1, place); \ - while (res != JNI_TRUE) { \ - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); \ - increase_buffer_size++; \ - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - try { \ - chars.resize(buffer_size); \ - } catch (std::bad_alloc const& e) { \ - throw doris::Exception( \ - ErrorCode::INTERNAL_ERROR, \ - "memory allocate failed in column string, buffer:{},size:{},reason:{}", \ - increase_buffer_size, buffer_size, e.what()); \ - } \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_intermediate_state_ptr = chars.size(); \ - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \ - to.size() - 1, place); \ - } \ - } else if (data_col.is_numeric() || data_col.is_column_decimal()) { \ - *output_value_buffer = reinterpret_cast(data_col.get_raw_data().data); \ - env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \ - to.size() - 1, place); \ - } else if (data_col.is_column_array()) { \ - ColumnArray& array_col = assert_cast(data_col); \ - ColumnNullable& array_nested_nullable = \ - assert_cast(array_col.get_data()); \ - auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); \ - auto data_column = array_nested_nullable.get_nested_column_ptr(); \ - auto& offset_column = array_col.get_offsets_column(); \ - int increase_buffer_size = 0; \ - int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - *output_offsets_ptr = reinterpret_cast(offset_column.get_raw_data().data); \ - data_column_null_map->resize(buffer_size); \ - auto& null_map_data = \ - assert_cast*>(data_column_null_map.get())->get_data(); \ - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); \ - *output_intermediate_state_ptr = buffer_size; \ - if (data_column->is_column_string()) { \ - ColumnString* str_col = assert_cast(data_column.get()); \ - ColumnString::Chars& chars = assert_cast(str_col->get_chars()); \ - ColumnString::Offsets& offsets = \ - assert_cast(str_col->get_offsets()); \ - chars.resize(buffer_size); \ - offsets.resize(buffer_size); \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); \ - jboolean res = env->CallNonvirtualBooleanMethod( \ - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); \ - while (res != JNI_TRUE) { \ - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); \ - increase_buffer_size++; \ - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - try { \ - null_map_data.resize(buffer_size); \ - chars.resize(buffer_size); \ - offsets.resize(buffer_size); \ - } catch (std::bad_alloc const& e) { \ - throw doris::Exception(ErrorCode::INTERNAL_ERROR, \ - "memory allocate failed in array column string, " \ - "buffer:{},size:{},reason:{}", \ - increase_buffer_size, buffer_size, e.what()); \ - } \ - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); \ - *output_intermediate_state_ptr = buffer_size; \ - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \ - executor_result_id, to.size() - 1, place); \ - } \ - } else { \ - data_column->resize(buffer_size); \ - *output_value_buffer = reinterpret_cast(data_column->get_raw_data().data); \ - jboolean res = env->CallNonvirtualBooleanMethod( \ - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); \ - while (res != JNI_TRUE) { \ - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); \ - increase_buffer_size++; \ - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - try { \ - null_map_data.resize(buffer_size); \ - data_column->resize(buffer_size); \ - } catch (std::bad_alloc const& e) { \ - throw doris::Exception(ErrorCode::INTERNAL_ERROR, \ - "memory allocate failed in array number column, " \ - "buffer:{},size:{},reason:{}", \ - increase_buffer_size, buffer_size, e.what()); \ - } \ - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); \ - *output_value_buffer = \ - reinterpret_cast(data_column->get_raw_data().data); \ - *output_intermediate_state_ptr = buffer_size; \ - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \ - executor_result_id, to.size() - 1, place); \ - } \ - } \ - } else { \ - return Status::InvalidArgument(strings::Substitute( \ - "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); \ - } -#endif - EVALUATE_JAVA_UDAF; + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); } else { *output_null_value = -1; auto& data_col = to; - EVALUATE_JAVA_UDAF; + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); } return JniUtil::GetJniExceptionMsg(env); } private: + Status get_result(IColumn& to, const DataTypePtr& result_type, int64_t place, JNIEnv* env, + IColumn& data_col) const { + if (data_col.is_column_string()) { + const ColumnString* str_col = check_and_get_column(data_col); + ColumnString::Chars& chars = const_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + const_cast(str_col->get_offsets()); + int increase_buffer_size = 0; + int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + chars.resize(buffer_size); + *output_value_buffer = reinterpret_cast(chars.data()); + *output_offsets_ptr = reinterpret_cast(offsets.data()); + *output_intermediate_state_ptr = chars.size(); + jboolean res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + while (res != JNI_TRUE) { + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + increase_buffer_size++; + buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + try { + chars.resize(buffer_size); + } catch (std::bad_alloc const& e) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "memory allocate failed in column string, " + "buffer:{},size:{},reason:{}", + increase_buffer_size, buffer_size, e.what()); + } + *output_value_buffer = reinterpret_cast(chars.data()); + *output_intermediate_state_ptr = chars.size(); + res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, + executor_result_id, to.size() - 1, place); + } + } else if (data_col.is_numeric() || data_col.is_column_decimal()) { + *output_value_buffer = reinterpret_cast(data_col.get_raw_data().data); + env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, + to.size() - 1, place); + } else if (data_col.is_column_array()) { + ColumnArray& array_col = assert_cast(data_col); + ColumnNullable& array_nested_nullable = + assert_cast(array_col.get_data()); + auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); + auto data_column = array_nested_nullable.get_nested_column_ptr(); + auto& offset_column = array_col.get_offsets_column(); + int increase_buffer_size = 0; + int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + *output_offsets_ptr = reinterpret_cast(offset_column.get_raw_data().data); + data_column_null_map->resize(buffer_size); + auto& null_map_data = + assert_cast*>(data_column_null_map.get())->get_data(); + *output_array_null_ptr = reinterpret_cast(null_map_data.data()); + *output_intermediate_state_ptr = buffer_size; + if (data_column->is_column_string()) { + ColumnString* str_col = assert_cast(data_column.get()); + ColumnString::Chars& chars = + assert_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + assert_cast(str_col->get_offsets()); + chars.resize(buffer_size); + offsets.resize(buffer_size); + *output_value_buffer = reinterpret_cast(chars.data()); + *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); + jboolean res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + while (res != JNI_TRUE) { + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + increase_buffer_size++; + buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + try { + null_map_data.resize(buffer_size); + chars.resize(buffer_size); + offsets.resize(buffer_size); + } catch (std::bad_alloc const& e) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "memory allocate failed in array column string, " + "buffer:{},size:{},reason:{}", + increase_buffer_size, buffer_size, e.what()); + } + *output_array_null_ptr = reinterpret_cast(null_map_data.data()); + *output_value_buffer = reinterpret_cast(chars.data()); + *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); + *output_intermediate_state_ptr = buffer_size; + res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + } + } else { + data_column->resize(buffer_size); + *output_value_buffer = reinterpret_cast(data_column->get_raw_data().data); + jboolean res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + while (res != JNI_TRUE) { + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + increase_buffer_size++; + buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + try { + null_map_data.resize(buffer_size); + data_column->resize(buffer_size); + } catch (std::bad_alloc const& e) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "memory allocate failed in array number column, " + "buffer:{},size:{},reason:{}", + increase_buffer_size, buffer_size, e.what()); + } + *output_array_null_ptr = reinterpret_cast(null_map_data.data()); + *output_value_buffer = + reinterpret_cast(data_column->get_raw_data().data); + *output_intermediate_state_ptr = buffer_size; + res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + } + } + } else { + return Status::InvalidArgument(strings::Substitute( + "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); + } + return Status::OK(); + } + Status register_func_id(JNIEnv* env) { auto register_id = [&](const char* func_name, const char* func_sign, jmethodID& func_id) { func_id = env->GetMethodID(executor_cl, func_name, func_sign); @@ -397,7 +437,6 @@ struct AggregateJavaUdafData { } return s; }; - RETURN_IF_ERROR(register_id("", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id)); RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id)); RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id)); @@ -413,6 +452,8 @@ struct AggregateJavaUdafData { executor_convert_basic_argument_id)); RETURN_IF_ERROR(register_id("convertArrayArguments", "(IZIIJJJJJ)[Ljava/lang/Object;", executor_convert_array_argument_id)); + RETURN_IF_ERROR(register_id("convertMapArguments", "(IZIIJJJJJJJJ)[Ljava/lang/Object;", + executor_convert_map_argument_id)); RETURN_IF_ERROR( register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id)); return Status::OK(); @@ -435,13 +476,7 @@ struct AggregateJavaUdafData { jmethodID executor_destroy_id; jmethodID executor_convert_basic_argument_id; jmethodID executor_convert_array_argument_id; - - std::unique_ptr input_values_buffer_ptr; - std::unique_ptr input_nulls_buffer_ptr; - std::unique_ptr input_offsets_ptrs; - std::unique_ptr input_array_nulls_buffer_ptr; - std::unique_ptr input_array_string_offsets_ptrs; - std::unique_ptr input_place_ptrs; + jmethodID executor_convert_map_argument_id; std::unique_ptr output_value_buffer; std::unique_ptr output_null_value; std::unique_ptr output_offsets_ptr; diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index 46bf887515249d..a2d41245517046 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -101,23 +101,6 @@ Status JavaFunctionCall::open(FunctionContext* context, FunctionContext::Functio TJavaUdfExecutorCtorParams ctor_params; ctor_params.__set_fn(fn_); ctor_params.__set_location(local_location); - ctor_params.__set_input_offsets_ptrs((int64_t)jni_ctx->input_offsets_ptrs.get()); - ctor_params.__set_input_buffer_ptrs((int64_t)jni_ctx->input_values_buffer_ptr.get()); - ctor_params.__set_input_nulls_ptrs((int64_t)jni_ctx->input_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_nulls_buffer_ptr( - (int64_t)jni_ctx->input_array_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_string_offsets_ptrs( - (int64_t)jni_ctx->input_array_string_offsets_ptrs.get()); - ctor_params.__set_output_buffer_ptr((int64_t)jni_ctx->output_value_buffer.get()); - ctor_params.__set_output_null_ptr((int64_t)jni_ctx->output_null_value.get()); - ctor_params.__set_output_offsets_ptr((int64_t)jni_ctx->output_offsets_ptr.get()); - ctor_params.__set_output_array_null_ptr((int64_t)jni_ctx->output_array_null_ptr.get()); - ctor_params.__set_output_array_string_offsets_ptr( - (int64_t)jni_ctx->output_array_string_offsets_ptr.get()); - ctor_params.__set_output_intermediate_state_ptr( - (int64_t)jni_ctx->output_intermediate_state_ptr.get()); - ctor_params.__set_batch_size_ptr((int64_t)jni_ctx->batch_size_ptr.get()); - jbyteArray ctor_params_bytes; // Pushed frame will be popped when jni_frame goes out-of-scope. @@ -255,7 +238,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, ->get_data() .data()); int64_t value_nested_data_address = 0, value_nested_offset_address = 0; - // array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address if (value_data_column->is_column_string()) { const ColumnString* col = assert_cast(value_data_column.get()); value_nested_data_address = reinterpret_cast(col->get_chars().data()); diff --git a/be/src/vec/functions/function_java_udf.h b/be/src/vec/functions/function_java_udf.h index c0828a2a3f742f..ddbe300e89c454 100644 --- a/be/src/vec/functions/function_java_udf.h +++ b/be/src/vec/functions/function_java_udf.h @@ -115,39 +115,8 @@ class JavaFunctionCall : public IFunctionBase { jobject executor = nullptr; bool is_closed = false; - std::unique_ptr input_values_buffer_ptr; - std::unique_ptr input_nulls_buffer_ptr; - std::unique_ptr input_offsets_ptrs; - //used for array type nested column null map, because array nested column must be nullable - std::unique_ptr input_array_nulls_buffer_ptr; - //used for array type of nested string column offset, not the array column offset - std::unique_ptr input_array_string_offsets_ptrs; - std::unique_ptr output_value_buffer; - std::unique_ptr output_null_value; - std::unique_ptr output_offsets_ptr; - //used for array type nested column null map - std::unique_ptr output_array_null_ptr; - //used for array type of nested string column offset - std::unique_ptr output_array_string_offsets_ptr; - std::unique_ptr batch_size_ptr; - // intermediate_state includes two parts: reserved / used buffer size and rows - std::unique_ptr output_intermediate_state_ptr; - JniContext(int64_t num_args, jclass executor_cl, jmethodID executor_close_id) - : executor_cl_(executor_cl), - executor_close_id_(executor_close_id), - input_values_buffer_ptr(new int64_t[num_args]), - input_nulls_buffer_ptr(new int64_t[num_args]), - input_offsets_ptrs(new int64_t[num_args]), - input_array_nulls_buffer_ptr(new int64_t[num_args]), - input_array_string_offsets_ptrs(new int64_t[num_args]), - output_value_buffer(new int64_t()), - output_null_value(new int64_t()), - output_offsets_ptr(new int64_t()), - output_array_null_ptr(new int64_t()), - output_array_string_offsets_ptr(new int64_t()), - batch_size_ptr(new int32_t()), - output_intermediate_state_ptr(new IntermediateState()) {} + : executor_cl_(executor_cl), executor_close_id_(executor_close_id) {} void close() { if (is_closed) { diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index df5026742d42a4..eae5270872cc3f 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -25,6 +25,7 @@ import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; +import com.esotericsoftware.reflectasm.MethodAccess; import com.google.common.base.Preconditions; import org.apache.log4j.Logger; import org.apache.thrift.TDeserializer; @@ -33,6 +34,7 @@ import java.io.IOException; import java.lang.reflect.Array; +import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; @@ -42,6 +44,8 @@ import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; public abstract class BaseExecutor { private static final Logger LOG = Logger.getLogger(BaseExecutor.class); @@ -88,12 +92,14 @@ public abstract class BaseExecutor { protected final long outputArrayStringOffsetsPtr; protected final long outputIntermediateStatePtr; protected Class[] argClass; + protected MethodAccess methodAccess; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used * by * the backend. */ + public BaseExecutor(byte[] thriftParams) throws Exception { TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); @@ -1320,4 +1326,498 @@ public Object[] convertMapArg(PrimitiveType type, int argIdx, boolean isNullable } return argument; } + + public Object[] buildHashMap(PrimitiveType keyType, PrimitiveType valueType, Object[] keyCol, Object[] valueCol) { + switch (keyType) { + case BOOLEAN: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case TINYINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case SMALLINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case INT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case BIGINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case LARGEINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case FLOAT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DOUBLE: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case CHAR: + case VARCHAR: + case STRING: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DATEV2: + case DATE: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DATETIMEV2: + case DATETIME: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + default: { + LOG.info("Not support: " + keyType); + Preconditions.checkState(false, "Not support type " + keyType.toString()); + break; + } + } + return null; + } + + public static class HashMapBuilder { + public Object[] get(Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { + switch (valueType) { + case BOOLEAN: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case TINYINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case SMALLINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case INT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case BIGINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case LARGEINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case FLOAT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DOUBLE: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case CHAR: + case VARCHAR: + case STRING: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DATEV2: + case DATE: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DATETIMEV2: + case DATETIME: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + return new BuildMapFromType().get(keyCol, valueCol); + } + default: { + LOG.info("Not support: " + valueType); + Preconditions.checkState(false, "Not support type " + valueType.toString()); + break; + } + } + return null; + } + } + + public static class BuildMapFromType { + public Object[] get(Object[] keyCol, Object[] valueCol) { + Object[] retHashMap = new HashMap[keyCol.length]; + for (int colIdx = 0; colIdx < keyCol.length; colIdx++) { + HashMap hashMap = new HashMap<>(); + ArrayList keys = (ArrayList) (keyCol[colIdx]); + ArrayList values = (ArrayList) (valueCol[colIdx]); + for (int i = 0; i < keys.size(); i++) { + T1 key = keys.get(i); + T2 value = values.get(i); + if (!hashMap.containsKey(key)) { + hashMap.put(key, value); + } + } + retHashMap[colIdx] = hashMap; + } + return retHashMap; + } + } + + public void copyBatchBasicResultImpl(boolean isNullable, int numRows, Object[] result, long nullMapAddr, + long resColumnAddr, long strOffsetAddr, Method method) { + switch (retType) { + case BOOLEAN: { + UdfConvert.copyBatchBooleanResult(isNullable, numRows, (Boolean[]) result, nullMapAddr, resColumnAddr); + break; + } + case TINYINT: { + UdfConvert.copyBatchTinyIntResult(isNullable, numRows, (Byte[]) result, nullMapAddr, resColumnAddr); + break; + } + case SMALLINT: { + UdfConvert.copyBatchSmallIntResult(isNullable, numRows, (Short[]) result, nullMapAddr, resColumnAddr); + break; + } + case INT: { + UdfConvert.copyBatchIntResult(isNullable, numRows, (Integer[]) result, nullMapAddr, resColumnAddr); + break; + } + case BIGINT: { + UdfConvert.copyBatchBigIntResult(isNullable, numRows, (Long[]) result, nullMapAddr, resColumnAddr); + break; + } + case LARGEINT: { + UdfConvert.copyBatchLargeIntResult(isNullable, numRows, (BigInteger[]) result, nullMapAddr, + resColumnAddr); + break; + } + case FLOAT: { + UdfConvert.copyBatchFloatResult(isNullable, numRows, (Float[]) result, nullMapAddr, resColumnAddr); + break; + } + case DOUBLE: { + UdfConvert.copyBatchDoubleResult(isNullable, numRows, (Double[]) result, nullMapAddr, resColumnAddr); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + UdfConvert.copyBatchStringResult(isNullable, numRows, (String[]) result, nullMapAddr, resColumnAddr, + strOffsetAddr); + break; + } + case DATE: { + UdfConvert.copyBatchDateResult(method.getReturnType(), isNullable, numRows, result, + nullMapAddr, resColumnAddr); + break; + } + case DATETIME: { + UdfConvert + .copyBatchDateTimeResult(method.getReturnType(), isNullable, numRows, result, + nullMapAddr, + resColumnAddr); + break; + } + case DATEV2: { + UdfConvert.copyBatchDateV2Result(method.getReturnType(), isNullable, numRows, result, + nullMapAddr, + resColumnAddr); + break; + } + case DATETIMEV2: { + UdfConvert.copyBatchDateTimeV2Result(method.getReturnType(), isNullable, numRows, + result, nullMapAddr, + resColumnAddr); + break; + } + case DECIMALV2: + case DECIMAL128: { + UdfConvert.copyBatchDecimal128Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, + nullMapAddr, + resColumnAddr); + break; + } + case DECIMAL32: { + UdfConvert.copyBatchDecimal32Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, + nullMapAddr, + resColumnAddr); + break; + } + case DECIMAL64: { + UdfConvert.copyBatchDecimal64Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, + nullMapAddr, + resColumnAddr); + break; + } + default: { + LOG.info("Not support return type: " + retType); + Preconditions.checkState(false, "Not support type: " + retType.toString()); + break; + } + } + } + + public void copyBatchArrayResultImpl(boolean isNullable, int numRows, Object[] result, long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr, + PrimitiveType type) { + long hasPutElementNum = 0; + for (int row = 0; row < numRows; ++row) { + switch (type) { + case BOOLEAN: { + hasPutElementNum = UdfConvert + .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case TINYINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case SMALLINT: { + hasPutElementNum = UdfConvert + .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case INT: { + hasPutElementNum = UdfConvert + .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case BIGINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case LARGEINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case FLOAT: { + hasPutElementNum = UdfConvert + .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DOUBLE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + hasPutElementNum = UdfConvert + .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); + break; + } + case DATE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIME: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIMEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMALV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL32: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL64: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL128: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + default: { + Preconditions.checkState(false, "Not support type in array: " + retType); + break; + } + } + } + } + + public void buildArrayListFromHashMap(Object[] result, PrimitiveType keyType, PrimitiveType valueType, + Object[] keyCol, Object[] valueCol) { + switch (keyType) { + case BOOLEAN: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case TINYINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case SMALLINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case INT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case BIGINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case LARGEINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case FLOAT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DOUBLE: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DATEV2: + case DATE: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DATETIMEV2: + case DATETIME: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + default: { + LOG.info("Not support: " + keyType); + Preconditions.checkState(false, "Not support type " + keyType.toString()); + break; + } + } + } + + public static class ArrayListBuilder { + public void get(Object[] map, Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { + switch (valueType) { + case BOOLEAN: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case TINYINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case SMALLINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case INT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case BIGINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case LARGEINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case FLOAT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DOUBLE: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DATEV2: + case DATE: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DATETIMEV2: + case DATETIME: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + default: { + LOG.info("Not support: " + valueType); + Preconditions.checkState(false, "Not support type " + valueType.toString()); + break; + } + } + } + } + + public static class BuildArrayFromType { + public void get(Object[] map, Object[] keyCol, Object[] valueCol) { + for (int colIdx = 0; colIdx < map.length; colIdx++) { + HashMap hashMap = (HashMap) map[colIdx]; + ArrayList keys = new ArrayList<>(); + ArrayList values = new ArrayList<>(); + for (Map.Entry entry : hashMap.entrySet()) { + keys.add(entry.getKey()); + values.add(entry.getValue()); + } + keyCol[colIdx] = keys; + valueCol[colIdx] = values; + } + } + } } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index a0736b5a72b0e5..dff689ed4038d6 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -17,6 +17,7 @@ package org.apache.doris.udf; +import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.Type; import org.apache.doris.common.Pair; import org.apache.doris.common.exception.UdfRuntimeException; @@ -52,7 +53,6 @@ public class UdafExecutor extends BaseExecutor { private HashMap stateObjMap; private Class retClass; private int addIndex; - private MethodAccess methodAccess; /** * Constructor to create an object. @@ -81,6 +81,21 @@ public Object[] convertArrayArguments(int argIdx, boolean isNullable, int rowSta dataAddr, strOffsetAddr); } + public Object[] convertMapArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr, + long offsetsAddr, long keyNestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr, + long valueNestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) { + PrimitiveType keyType = argTypes[argIdx].getKeyType().getPrimitiveType(); + PrimitiveType valueType = argTypes[argIdx].getValueType().getPrimitiveType(); + Object[] keyCol = convertMapArg(keyType, argIdx, isNullable, rowStart, rowEnd, nullMapAddr, offsetsAddr, + keyNestedNullMapAddr, keyDataAddr, + keyStrOffsetAddr); + Object[] valueCol = convertMapArg(valueType, argIdx, isNullable, rowStart, rowEnd, nullMapAddr, offsetsAddr, + valueNestedNullMapAddr, + valueDataAddr, + valueStrOffsetAddr); + return buildHashMap(keyType, valueType, keyCol, valueCol); + } + public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset, Object[] column) throws UdfRuntimeException { if (isSinglePlace) { @@ -111,7 +126,7 @@ public void addBatchSingle(int rowStart, int rowEnd, long placeAddr, Object[] co methodAccess.invoke(udf, addIndex, inputArgs); } } catch (Exception e) { - LOG.warn("invoke add function meet some error: " + e.getCause().toString()); + LOG.info("invoke add function meet some error: " + e.getCause().toString()); throw new UdfRuntimeException("UDAF failed to addBatchSingle: ", e); } } @@ -143,7 +158,7 @@ public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, methodAccess.invoke(udf, addIndex, inputArgs); } } catch (Exception e) { - LOG.warn("invoke add function meet some error: " + Arrays.toString(e.getStackTrace())); + LOG.info("invoke add function meet some error: " + Arrays.toString(e.getStackTrace())); throw new UdfRuntimeException("UDAF failed to addBatchPlaces: ", e); } } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 1140d1824bae6f..2f6ca99fdd5120 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -34,14 +34,8 @@ import java.lang.reflect.Array; import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.math.BigDecimal; -import java.math.BigInteger; import java.net.MalformedURLException; -import java.time.LocalDate; -import java.time.LocalDateTime; import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; public class UdfExecutor extends BaseExecutor { // private static final java.util.logging.Logger LOG = @@ -60,7 +54,6 @@ public class UdfExecutor extends BaseExecutor { private long batchSizePtr; private int evaluateIndex; - private MethodAccess methodAccess; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used by @@ -147,57 +140,7 @@ public Object[] convertMapArguments(int argIdx, boolean isNullable, int numRows, Object[] valueCol = convertMapArg(valueType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr, valueNestedNullMapAddr, valueDataAddr, valueStrOffsetAddr); - switch (keyType) { - case BOOLEAN: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case TINYINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case SMALLINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case INT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case BIGINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case LARGEINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case FLOAT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DOUBLE: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case CHAR: - case VARCHAR: - case STRING: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DATEV2: - case DATE: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DATETIMEV2: - case DATETIME: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - default: { - LOG.info("Not support: " + keyType); - Preconditions.checkState(false, "Not support type " + keyType.toString()); - break; - } - } - return null; + return buildHashMap(keyType, valueType, keyCol, valueCol); } /** @@ -223,217 +166,7 @@ public Object[] evaluate(int numRows, Object[] column) throws UdfRuntimeExceptio public void copyBatchBasicResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr, long resColumnAddr, long strOffsetAddr) { - switch (retType) { - case BOOLEAN: { - UdfConvert.copyBatchBooleanResult(isNullable, numRows, (Boolean[]) result, nullMapAddr, resColumnAddr); - break; - } - case TINYINT: { - UdfConvert.copyBatchTinyIntResult(isNullable, numRows, (Byte[]) result, nullMapAddr, resColumnAddr); - break; - } - case SMALLINT: { - UdfConvert.copyBatchSmallIntResult(isNullable, numRows, (Short[]) result, nullMapAddr, resColumnAddr); - break; - } - case INT: { - UdfConvert.copyBatchIntResult(isNullable, numRows, (Integer[]) result, nullMapAddr, resColumnAddr); - break; - } - case BIGINT: { - UdfConvert.copyBatchBigIntResult(isNullable, numRows, (Long[]) result, nullMapAddr, resColumnAddr); - break; - } - case LARGEINT: { - UdfConvert.copyBatchLargeIntResult(isNullable, numRows, (BigInteger[]) result, nullMapAddr, - resColumnAddr); - break; - } - case FLOAT: { - UdfConvert.copyBatchFloatResult(isNullable, numRows, (Float[]) result, nullMapAddr, resColumnAddr); - break; - } - case DOUBLE: { - UdfConvert.copyBatchDoubleResult(isNullable, numRows, (Double[]) result, nullMapAddr, resColumnAddr); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - UdfConvert.copyBatchStringResult(isNullable, numRows, (String[]) result, nullMapAddr, resColumnAddr, - strOffsetAddr); - break; - } - case DATE: { - UdfConvert.copyBatchDateResult(method.getReturnType(), isNullable, numRows, result, - nullMapAddr, resColumnAddr); - break; - } - case DATETIME: { - UdfConvert - .copyBatchDateTimeResult(method.getReturnType(), isNullable, numRows, result, - nullMapAddr, - resColumnAddr); - break; - } - case DATEV2: { - UdfConvert.copyBatchDateV2Result(method.getReturnType(), isNullable, numRows, result, - nullMapAddr, - resColumnAddr); - break; - } - case DATETIMEV2: { - UdfConvert.copyBatchDateTimeV2Result(method.getReturnType(), isNullable, numRows, - result, nullMapAddr, - resColumnAddr); - break; - } - case DECIMALV2: - case DECIMAL128: { - UdfConvert.copyBatchDecimal128Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, - nullMapAddr, - resColumnAddr); - break; - } - case DECIMAL32: { - UdfConvert.copyBatchDecimal32Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, - nullMapAddr, - resColumnAddr); - break; - } - case DECIMAL64: { - UdfConvert.copyBatchDecimal64Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, - nullMapAddr, - resColumnAddr); - break; - } - default: { - LOG.info("Not support return type: " + retType); - Preconditions.checkState(false, "Not support type: " + retType.toString()); - break; - } - } - } - - public void copyBatchArrayResultImpl(boolean isNullable, int numRows, Object[] result, long nullMapAddr, - long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr, - PrimitiveType type) { - long hasPutElementNum = 0; - for (int row = 0; row < numRows; ++row) { - switch (type) { - case BOOLEAN: { - hasPutElementNum = UdfConvert - .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case TINYINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case SMALLINT: { - hasPutElementNum = UdfConvert - .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case INT: { - hasPutElementNum = UdfConvert - .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case BIGINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case LARGEINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case FLOAT: { - hasPutElementNum = UdfConvert - .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DOUBLE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - hasPutElementNum = UdfConvert - .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); - break; - } - case DATE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIME: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIMEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMALV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL32: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL64: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL128: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - default: { - Preconditions.checkState(false, "Not support type in array: " + retType); - break; - } - } - } + copyBatchBasicResultImpl(isNullable, numRows, result, nullMapAddr, resColumnAddr, strOffsetAddr, getMethod()); } public void copyBatchArrayResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr, @@ -453,68 +186,7 @@ public void copyBatchMapResult(boolean isNullable, int numRows, Object[] result, PrimitiveType valueType = retType.getValueType().getPrimitiveType(); Object[] keyCol = new Object[result.length]; Object[] valueCol = new Object[result.length]; - switch (keyType) { - case BOOLEAN: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case TINYINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case SMALLINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case INT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case BIGINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case LARGEINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case FLOAT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DOUBLE: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DATEV2: - case DATE: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DATETIMEV2: - case DATETIME: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - default: { - LOG.info("Not support: " + keyType); - Preconditions.checkState(false, "Not support type " + keyType.toString()); - break; - } - } + buildArrayListFromHashMap(result, keyType, valueType, keyCol, valueCol); copyBatchArrayResultImpl(isNullable, numRows, valueCol, nullMapAddr, offsetsAddr, valueNsestedNullMapAddr, valueDataAddr, @@ -522,7 +194,6 @@ public void copyBatchMapResult(boolean isNullable, int numRows, Object[] result, copyBatchArrayResultImpl(isNullable, numRows, keyCol, nullMapAddr, offsetsAddr, keyNsestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyType); - } /** @@ -669,163 +340,4 @@ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type fun } } - public static class HashMapBuilder { - public Object[] get(Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { - switch (valueType) { - case BOOLEAN: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case TINYINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case SMALLINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case INT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case BIGINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case LARGEINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case FLOAT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DOUBLE: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case CHAR: - case VARCHAR: - case STRING: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DATEV2: - case DATE: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DATETIMEV2: - case DATETIME: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - return new BuildMapFromType().get(keyCol, valueCol); - } - default: { - LOG.info("Not support: " + valueType); - Preconditions.checkState(false, "Not support type " + valueType.toString()); - break; - } - } - return null; - } - } - - public static class BuildMapFromType { - public Object[] get(Object[] keyCol, Object[] valueCol) { - Object[] retHashMap = new HashMap[keyCol.length]; - for (int colIdx = 0; colIdx < keyCol.length; colIdx++) { - HashMap hashMap = new HashMap<>(); - ArrayList keys = (ArrayList) (keyCol[colIdx]); - ArrayList values = (ArrayList) (valueCol[colIdx]); - for (int i = 0; i < keys.size(); i++) { - T1 key = keys.get(i); - T2 value = values.get(i); - if (!hashMap.containsKey(key)) { - hashMap.put(key, value); - } - } - retHashMap[colIdx] = hashMap; - } - return retHashMap; - } - } - - public static class ArrayListBuilder { - public void get(Object[] map, Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { - switch (valueType) { - case BOOLEAN: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case TINYINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case SMALLINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case INT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case BIGINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case LARGEINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case FLOAT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DOUBLE: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DATEV2: - case DATE: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DATETIMEV2: - case DATETIME: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - default: { - LOG.info("Not support: " + valueType); - Preconditions.checkState(false, "Not support type " + valueType.toString()); - break; - } - } - } - } - - public static class BuildArrayFromType { - public void get(Object[] map, Object[] keyCol, Object[] valueCol) { - for (int colIdx = 0; colIdx < map.length; colIdx++) { - HashMap hashMap = (HashMap) map[colIdx]; - ArrayList keys = new ArrayList<>(); - ArrayList values = new ArrayList<>(); - for (Map.Entry entry : hashMap.entrySet()) { - keys.add(entry.getKey()); - values.add(entry.getValue()); - } - keyCol[colIdx] = keys; - valueCol[colIdx] = values; - } - } - } - } diff --git a/regression-test/data/javaudf_p0/test_javaudf_agg_map.out b/regression-test/data/javaudf_p0/test_javaudf_agg_map.out new file mode 100644 index 00000000000000..b4093b461ca210 --- /dev/null +++ b/regression-test/data/javaudf_p0/test_javaudf_agg_map.out @@ -0,0 +1,7 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select_1 -- +616.0 + +-- !select_2 -- +342 + diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapInt.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapInt.java new file mode 100644 index 00000000000000..6310355ccb8cba --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapInt.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class MySumMapInt { + private static final Logger LOG = Logger.getLogger(MySumMapInt.class); + public static class State { + public long counter = 0; + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, HashMap val) { + if (val == null) { + return; + } + for(Map.Entry it : val.entrySet()){ + Integer key = it.getKey(); + Integer value = it.getValue(); + state.counter += key + value; + } + } + + public void serialize(State state, DataOutputStream out) throws IOException { + out.writeLong(state.counter); + } + + public void deserialize(State state, DataInputStream in) throws IOException { + state.counter = in.readLong(); + } + + public void merge(State state, State rhs) { + state.counter += rhs.counter; + } + + public long getValue(State state) { + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapIntDou.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapIntDou.java new file mode 100644 index 00000000000000..7690bba6037443 --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapIntDou.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class MySumMapIntDou { + private static final Logger LOG = Logger.getLogger(MySumMapIntDou.class); + public static class State { + public Double counter = 0.0; + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, HashMap val) { + if (val == null) { + return; + } + for(Map.Entry it : val.entrySet()){ + Integer key = it.getKey(); + Double value = it.getValue(); + state.counter += key * value; + } + } + + public void serialize(State state, DataOutputStream out) throws IOException { + out.writeDouble(state.counter); + } + + public void deserialize(State state, DataInputStream in) throws IOException { + state.counter = in.readDouble(); + } + + public void merge(State state, State rhs) { + state.counter += rhs.counter; + } + + public Double getValue(State state) { + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/suites/javaudf_p0/test_javaudf_agg_map.groovy b/regression-test/suites/javaudf_p0/test_javaudf_agg_map.groovy new file mode 100644 index 00000000000000..facd8fe1f9c3c5 --- /dev/null +++ b/regression-test/suites/javaudf_p0/test_javaudf_agg_map.groovy @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import org.codehaus.groovy.runtime.IOGroovyMethods + +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.Paths + +suite("test_javaudf_agg_map") { + def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar""" + log.info("Jar path: ${jarPath}".toString()) + try { + try_sql("DROP FUNCTION IF EXISTS mapii(Map);") + try_sql("DROP FUNCTION IF EXISTS mapid(Map);") + try_sql("DROP TABLE IF EXISTS db") + sql """ + CREATE TABLE IF NOT EXISTS db( + `id` INT NULL COMMENT "", + `i` INT NULL COMMENT "", + `d` Double NULL COMMENT "", + `mii` Map NULL COMMENT "", + `mid` Map NULL COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2"); + """ + sql """ INSERT INTO db VALUES(1, 10,1.1,{1:1,10:1,100:1},{1:1.1,11:11.1}); """ + sql """ INSERT INTO db VALUES(2, 20,2.2,{2:2,20:2,200:2},{2:2.2,22:22.2}); """ + + sql """ + + CREATE AGGREGATE FUNCTION mapii(Map) RETURNS BigInt PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumMapInt", + "type"="JAVA_UDF" + ); + + """ + + sql """ + + CREATE AGGREGATE FUNCTION mapid(Map) RETURNS Double PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumMapIntDou", + "type"="JAVA_UDF" + ); + + """ + + + qt_select_1 """ select mapid(mid) from db; """ + + qt_select_2 """ select mapii(mii) from db; """ + + } finally { + try_sql("DROP FUNCTION IF EXISTS mapii(Map);") + try_sql("DROP FUNCTION IF EXISTS mapid(Map);") + try_sql("DROP TABLE IF EXISTS db") + } +} diff --git a/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy b/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy index df8baa37f98343..8421a6699a2875 100644 --- a/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy +++ b/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy @@ -25,6 +25,12 @@ suite("test_javaudf_ret_map") { def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar""" log.info("Jar path: ${jarPath}".toString()) try { + try_sql("DROP FUNCTION IF EXISTS retii(map);") + try_sql("DROP FUNCTION IF EXISTS retss(map);") + try_sql("DROP FUNCTION IF EXISTS retid(map);") + try_sql("DROP FUNCTION IF EXISTS retidss(int ,double);") + try_sql("DROP TABLE IF EXISTS db") + try_sql("DROP TABLE IF EXISTS dbss") sql """ CREATE TABLE IF NOT EXISTS db( `id` INT NULL COMMENT "", From 9337029d1aad78e594a95ea80d1b8c3ee793b581 Mon Sep 17 00:00:00 2001 From: Mryange <59914473+Mryange@users.noreply.github.com> Date: Wed, 9 Aug 2023 22:44:07 +0800 Subject: [PATCH 2/2] [refactor](udaf) refactor call udaf function and support map type in return (#22508) --- .../aggregate_function_java_udaf.h | 246 ++-- be/src/vec/functions/function_java_udf.cpp | 44 +- .../org/apache/doris/udf/BaseExecutor.java | 1091 +++-------------- .../org/apache/doris/udf/UdafExecutor.java | 104 +- .../java/org/apache/doris/udf/UdfConvert.java | 80 +- .../org/apache/doris/udf/UdfExecutor.java | 68 - .../org/apache/doris/udf/UdfExecutorTest.java | 600 --------- .../javaudf_p0/test_javaudaf_return_map.out | 31 + .../apache/doris/udf/MyReturnMapString.java | 75 ++ .../apache/doris/udf/MySumReturnMapInt.java | 73 ++ .../doris/udf/MySumReturnMapIntDou.java | 74 ++ .../test_javaudaf_return_map.groovy | 104 ++ 12 files changed, 760 insertions(+), 1830 deletions(-) delete mode 100644 fe/be-java-extensions/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java create mode 100644 regression-test/data/javaudf_p0/test_javaudaf_return_map.out create mode 100644 regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MyReturnMapString.java create mode 100644 regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapInt.java create mode 100644 regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapIntDou.java create mode 100644 regression-test/suites/javaudf_p0/test_javaudaf_return_map.groovy diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h index d51c219f3f108f..defd33b47546a2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -25,6 +25,7 @@ #include "common/compiler_util.h" #include "common/exception.h" +#include "common/logging.h" #include "common/status.h" #include "gutil/strings/substitute.h" #include "runtime/user_function_cache.h" @@ -55,15 +56,7 @@ const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V"; struct AggregateJavaUdafData { public: AggregateJavaUdafData() = default; - AggregateJavaUdafData(int64_t num_args) { - argument_size = num_args; - output_value_buffer = std::make_unique(0); - output_null_value = std::make_unique(0); - output_offsets_ptr = std::make_unique(0); - output_intermediate_state_ptr = std::make_unique(0); - output_array_null_ptr = std::make_unique(0); - output_array_string_offsets_ptr = std::make_unique(0); - } + AggregateJavaUdafData(int64_t num_args) { argument_size = num_args; } ~AggregateJavaUdafData() { JNIEnv* env; @@ -89,16 +82,6 @@ struct AggregateJavaUdafData { ctor_params.__set_fn(fn); ctor_params.__set_location(local_location); - ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get()); - - ctor_params.__set_output_null_ptr((int64_t)output_null_value.get()); - ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get()); - ctor_params.__set_output_intermediate_state_ptr( - (int64_t)output_intermediate_state_ptr.get()); - ctor_params.__set_output_array_null_ptr((int64_t)output_array_null_ptr.get()); - ctor_params.__set_output_array_string_offsets_ptr( - (int64_t)output_array_string_offsets_ptr.get()); - jbyteArray ctor_params_bytes; // Pushed frame will be popped when jni_frame goes out-of-scope. @@ -295,23 +278,27 @@ struct AggregateJavaUdafData { to.insert_default(); JNIEnv* env = nullptr; RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function"); + int64_t nullmap_address = 0; if (result_type->is_nullable()) { auto& nullable = assert_cast(to); - *output_null_value = + nullmap_address = reinterpret_cast(nullable.get_null_map_column().get_raw_data().data); auto& data_col = nullable.get_nested_column(); - RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col, nullmap_address)); } else { - *output_null_value = -1; + nullmap_address = -1; auto& data_col = to; - RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col, nullmap_address)); } return JniUtil::GetJniExceptionMsg(env); } private: - Status get_result(IColumn& to, const DataTypePtr& result_type, int64_t place, JNIEnv* env, - IColumn& data_col) const { + Status get_result(IColumn& to, const DataTypePtr& return_type, int64_t place, JNIEnv* env, + IColumn& data_col, int64_t nullmap_address) const { + jobject result_obj = env->CallNonvirtualObjectMethod(executor_obj, executor_cl, + executor_get_value_id, place); + bool result_nullable = return_type->is_nullable(); if (data_col.is_column_string()) { const ColumnString* str_col = check_and_get_column(data_col); ColumnString::Chars& chars = const_cast(str_col->get_chars()); @@ -320,109 +307,119 @@ struct AggregateJavaUdafData { int increase_buffer_size = 0; int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); chars.resize(buffer_size); - *output_value_buffer = reinterpret_cast(chars.data()); - *output_offsets_ptr = reinterpret_cast(offsets.data()); - *output_intermediate_state_ptr = chars.size(); - jboolean res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - while (res != JNI_TRUE) { - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); - increase_buffer_size++; - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - try { - chars.resize(buffer_size); - } catch (std::bad_alloc const& e) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "memory allocate failed in column string, " - "buffer:{},size:{},reason:{}", - increase_buffer_size, buffer_size, e.what()); - } - *output_value_buffer = reinterpret_cast(chars.data()); - *output_intermediate_state_ptr = chars.size(); - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, - executor_result_id, to.size() - 1, place); - } + env->CallNonvirtualVoidMethod( + executor_obj, executor_cl, executor_copy_basic_result_id, result_obj, + to.size() - 1, nullmap_address, reinterpret_cast(chars.data()), + reinterpret_cast(&chars), reinterpret_cast(offsets.data())); } else if (data_col.is_numeric() || data_col.is_column_decimal()) { - *output_value_buffer = reinterpret_cast(data_col.get_raw_data().data); - env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, - to.size() - 1, place); + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_basic_result_id, + result_obj, to.size() - 1, nullmap_address, + reinterpret_cast(data_col.get_raw_data().data), + 0, 0); } else if (data_col.is_column_array()) { - ColumnArray& array_col = assert_cast(data_col); + jclass arraylist_class = env->FindClass("Ljava/util/ArrayList;"); + ColumnArray* array_col = assert_cast(&data_col); ColumnNullable& array_nested_nullable = - assert_cast(array_col.get_data()); + assert_cast(array_col->get_data()); auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); auto data_column = array_nested_nullable.get_nested_column_ptr(); - auto& offset_column = array_col.get_offsets_column(); - int increase_buffer_size = 0; - int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - *output_offsets_ptr = reinterpret_cast(offset_column.get_raw_data().data); - data_column_null_map->resize(buffer_size); + auto& offset_column = array_col->get_offsets_column(); + auto offset_address = reinterpret_cast(offset_column.get_raw_data().data); auto& null_map_data = assert_cast*>(data_column_null_map.get())->get_data(); - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); - *output_intermediate_state_ptr = buffer_size; + auto nested_nullmap_address = reinterpret_cast(null_map_data.data()); + jmethodID list_size = env->GetMethodID(arraylist_class, "size", "()I"); + + size_t has_put_element_size = array_col->get_offsets().back(); + size_t arrar_list_size = env->CallIntMethod(result_obj, list_size); + size_t element_size = has_put_element_size + arrar_list_size; + array_nested_nullable.resize(element_size); + memset(null_map_data.data() + has_put_element_size, 0, arrar_list_size); + int64_t nested_data_address = 0, nested_offset_address = 0; if (data_column->is_column_string()) { ColumnString* str_col = assert_cast(data_column.get()); ColumnString::Chars& chars = assert_cast(str_col->get_chars()); ColumnString::Offsets& offsets = assert_cast(str_col->get_offsets()); - chars.resize(buffer_size); - offsets.resize(buffer_size); - *output_value_buffer = reinterpret_cast(chars.data()); - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); - jboolean res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - while (res != JNI_TRUE) { - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); - increase_buffer_size++; - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - try { - null_map_data.resize(buffer_size); - chars.resize(buffer_size); - offsets.resize(buffer_size); - } catch (std::bad_alloc const& e) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "memory allocate failed in array column string, " - "buffer:{},size:{},reason:{}", - increase_buffer_size, buffer_size, e.what()); - } - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); - *output_value_buffer = reinterpret_cast(chars.data()); - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); - *output_intermediate_state_ptr = buffer_size; - res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - } + nested_data_address = reinterpret_cast(&chars); + nested_offset_address = reinterpret_cast(offsets.data()); } else { - data_column->resize(buffer_size); - *output_value_buffer = reinterpret_cast(data_column->get_raw_data().data); - jboolean res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - while (res != JNI_TRUE) { - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); - increase_buffer_size++; - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - try { - null_map_data.resize(buffer_size); - data_column->resize(buffer_size); - } catch (std::bad_alloc const& e) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "memory allocate failed in array number column, " - "buffer:{},size:{},reason:{}", - increase_buffer_size, buffer_size, e.what()); - } - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); - *output_value_buffer = - reinterpret_cast(data_column->get_raw_data().data); - *output_intermediate_state_ptr = buffer_size; - res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - } + nested_data_address = reinterpret_cast(data_column->get_raw_data().data); + } + int row = to.size() - 1; + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_array_result_id, + has_put_element_size, result_nullable, row, result_obj, + nullmap_address, offset_address, nested_nullmap_address, + nested_data_address, nested_offset_address); + env->DeleteLocalRef(arraylist_class); + } else if (data_col.is_column_map()) { + jclass hashmap_class = env->FindClass("Ljava/util/HashMap;"); + ColumnMap* map_col = assert_cast(&data_col); + auto& offset_column = map_col->get_offsets_column(); + auto offset_address = reinterpret_cast(offset_column.get_raw_data().data); + ColumnNullable& map_key_column_nullable = + assert_cast(map_col->get_keys()); + auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr(); + auto key_data_column = map_key_column_nullable.get_nested_column_ptr(); + auto& key_null_map_data = + assert_cast*>(key_data_column_null_map.get())->get_data(); + auto key_nested_nullmap_address = reinterpret_cast(key_null_map_data.data()); + ColumnNullable& map_value_column_nullable = + assert_cast(map_col->get_values()); + auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr(); + auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); + auto& value_null_map_data = + assert_cast*>(value_data_column_null_map.get())->get_data(); + auto value_nested_nullmap_address = + reinterpret_cast(value_null_map_data.data()); + jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I"); + size_t has_put_element_size = map_col->get_offsets().back(); + size_t hashmap_size = env->CallIntMethod(result_obj, map_size); + size_t element_size = has_put_element_size + hashmap_size; + map_key_column_nullable.resize(element_size); + memset(key_null_map_data.data() + has_put_element_size, 0, hashmap_size); + map_value_column_nullable.resize(element_size); + memset(value_null_map_data.data() + has_put_element_size, 0, hashmap_size); + + int64_t key_nested_data_address = 0, key_nested_offset_address = 0; + if (key_data_column->is_column_string()) { + ColumnString* str_col = assert_cast(key_data_column.get()); + ColumnString::Chars& chars = + assert_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + assert_cast(str_col->get_offsets()); + key_nested_data_address = reinterpret_cast(&chars); + key_nested_offset_address = reinterpret_cast(offsets.data()); + } else { + key_nested_data_address = + reinterpret_cast(key_data_column->get_raw_data().data); } + + int64_t value_nested_data_address = 0, value_nested_offset_address = 0; + if (value_data_column->is_column_string()) { + ColumnString* str_col = assert_cast(value_data_column.get()); + ColumnString::Chars& chars = + assert_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + assert_cast(str_col->get_offsets()); + value_nested_data_address = reinterpret_cast(&chars); + value_nested_offset_address = reinterpret_cast(offsets.data()); + } else { + value_nested_data_address = + reinterpret_cast(value_data_column->get_raw_data().data); + } + int row = to.size() - 1; + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_map_result_id, + has_put_element_size, result_nullable, row, result_obj, + nullmap_address, offset_address, + key_nested_nullmap_address, key_nested_data_address, + key_nested_offset_address, value_nested_nullmap_address, + value_nested_data_address, value_nested_offset_address); + env->DeleteLocalRef(hashmap_class); } else { return Status::InvalidArgument(strings::Substitute( - "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); + "Java UDAF doesn't support return type is $0 now !", return_type->get_name())); } return Status::OK(); } @@ -438,14 +435,12 @@ struct AggregateJavaUdafData { return s; }; RETURN_IF_ERROR(register_id("", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id)); - RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id)); RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id)); RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, executor_close_id)); RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id)); RETURN_IF_ERROR( register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id)); - RETURN_IF_ERROR( - register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id)); + RETURN_IF_ERROR(register_id("getValue", "(J)Ljava/lang/Object;", executor_get_value_id)); RETURN_IF_ERROR( register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id)); RETURN_IF_ERROR(register_id("convertBasicArguments", "(IZIIJJJ)[Ljava/lang/Object;", @@ -454,6 +449,16 @@ struct AggregateJavaUdafData { executor_convert_array_argument_id)); RETURN_IF_ERROR(register_id("convertMapArguments", "(IZIIJJJJJJJJ)[Ljava/lang/Object;", executor_convert_map_argument_id)); + + RETURN_IF_ERROR(register_id("copyTupleBasicResult", "(Ljava/lang/Object;IJJJJ)V", + executor_copy_basic_result_id)); + + RETURN_IF_ERROR(register_id("copyTupleArrayResult", "(JZILjava/lang/Object;JJJJJ)V", + executor_copy_array_result_id)); + + RETURN_IF_ERROR(register_id("copyTupleMapResult", "(JZILjava/lang/Object;JJJJJJJJ)V", + executor_copy_map_result_id)); + RETURN_IF_ERROR( register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id)); return Status::OK(); @@ -466,24 +471,19 @@ struct AggregateJavaUdafData { jobject executor_obj; jmethodID executor_ctor_id; - jmethodID executor_add_id; jmethodID executor_add_batch_id; jmethodID executor_merge_id; jmethodID executor_serialize_id; - jmethodID executor_result_id; + jmethodID executor_get_value_id; jmethodID executor_reset_id; jmethodID executor_close_id; jmethodID executor_destroy_id; jmethodID executor_convert_basic_argument_id; jmethodID executor_convert_array_argument_id; jmethodID executor_convert_map_argument_id; - std::unique_ptr output_value_buffer; - std::unique_ptr output_null_value; - std::unique_ptr output_offsets_ptr; - std::unique_ptr output_intermediate_state_ptr; - std::unique_ptr output_array_null_ptr; - std::unique_ptr output_array_string_offsets_ptr; - + jmethodID executor_copy_basic_result_id; + jmethodID executor_copy_array_result_id; + jmethodID executor_copy_map_result_id; int argument_size = 0; std::string serialize_data; }; diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index a2d41245517046..7c50e74117716b 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -346,6 +346,27 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, auto& key_null_map_data = assert_cast*>(key_data_column_null_map.get())->get_data(); auto key_nested_nullmap_address = reinterpret_cast(key_null_map_data.data()); + ColumnNullable& map_value_column_nullable = + assert_cast(map_col->get_values()); + auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr(); + auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); + auto& value_null_map_data = + assert_cast*>(value_data_column_null_map.get())->get_data(); + auto value_nested_nullmap_address = reinterpret_cast(value_null_map_data.data()); + jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I"); + int element_size = 0; // get all element size in num_rows of map column + for (int i = 0; i < num_rows; ++i) { + jobject obj = env->GetObjectArrayElement(result_obj, i); + if (obj == nullptr) { + continue; + } + element_size = element_size + env->CallIntMethod(obj, map_size); + env->DeleteLocalRef(obj); + } + map_key_column_nullable.resize(element_size); + memset(key_null_map_data.data(), 0, element_size); + map_value_column_nullable.resize(element_size); + memset(value_null_map_data.data(), 0, element_size); int64_t key_nested_data_address = 0, key_nested_offset_address = 0; if (key_data_column->is_column_string()) { ColumnString* str_col = assert_cast(key_data_column.get()); @@ -358,16 +379,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, key_nested_data_address = reinterpret_cast(key_data_column->get_raw_data().data); } - - ColumnNullable& map_value_column_nullable = - assert_cast(map_col->get_values()); - auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr(); - auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); - auto& value_null_map_data = - assert_cast*>(value_data_column_null_map.get())->get_data(); - auto value_nested_nullmap_address = reinterpret_cast(value_null_map_data.data()); int64_t value_nested_data_address = 0, value_nested_offset_address = 0; - // array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address if (value_data_column->is_column_string()) { ColumnString* str_col = assert_cast(value_data_column.get()); ColumnString::Chars& chars = assert_cast(str_col->get_chars()); @@ -379,20 +391,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, value_nested_data_address = reinterpret_cast(value_data_column->get_raw_data().data); } - jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I"); - int element_size = 0; // get all element size in num_rows of map column - for (int i = 0; i < num_rows; ++i) { - jobject obj = env->GetObjectArrayElement(result_obj, i); - if (obj == nullptr) { - continue; - } - element_size = element_size + env->CallIntMethod(obj, map_size); - env->DeleteLocalRef(obj); - } - map_key_column_nullable.resize(element_size); - memset(key_null_map_data.data(), 0, element_size); - map_value_column_nullable.resize(element_size); - memset(value_null_map_data.data(), 0, element_size); env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl, jni_env->executor_result_map_batch_id, result_nullable, num_rows, result_obj, nullmap_address, offset_address, diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index eae5270872cc3f..20f36866c8b168 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.exception.InternalException; import org.apache.doris.common.exception.UdfRuntimeException; +import org.apache.doris.common.jni.utils.JNINativeMethod; import org.apache.doris.common.jni.utils.UdfUtils; import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; @@ -73,24 +74,6 @@ public abstract class BaseExecutor { // The JavaUdfDataType enum maps it to corresponding primitive type. protected JavaUdfDataType[] argTypes; protected JavaUdfDataType retType; - - // Input buffer from the backend. This is valid for the duration of an - // evaluate() call. - // These buffers are allocated in the BE. - protected final long inputBufferPtrs; - protected final long inputNullsPtrs; - protected final long inputOffsetsPtrs; - protected final long inputArrayNullsPtrs; - protected final long inputArrayStringOffsetsPtrs; - - // Output buffer to return non-string values. These buffers are allocated in the - // BE. - protected final long outputBufferPtr; - protected final long outputNullPtr; - protected final long outputOffsetsPtr; - protected final long outputArrayNullPtr; - protected final long outputArrayStringOffsetsPtr; - protected final long outputIntermediateStatePtr; protected Class[] argClass; protected MethodAccess methodAccess; @@ -108,18 +91,6 @@ public BaseExecutor(byte[] thriftParams) throws Exception { } catch (TException e) { throw new InternalException(e.getMessage()); } - inputBufferPtrs = request.input_buffer_ptrs; - inputNullsPtrs = request.input_nulls_ptrs; - inputOffsetsPtrs = request.input_offsets_ptrs; - inputArrayNullsPtrs = request.input_array_nulls_buffer_ptr; - inputArrayStringOffsetsPtrs = request.input_array_string_offsets_ptrs; - outputBufferPtr = request.output_buffer_ptr; - outputNullPtr = request.output_null_ptr; - outputOffsetsPtr = request.output_offsets_ptr; - outputIntermediateStatePtr = request.output_intermediate_state_ptr; - outputArrayNullPtr = request.output_array_null_ptr; - outputArrayStringOffsetsPtr = request.output_array_string_offsets_ptr; - Type[] parameterTypes = new Type[request.fn.arg_types.size()]; for (int i = 0; i < request.fn.arg_types.size(); ++i) { parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i)); @@ -132,359 +103,6 @@ public BaseExecutor(byte[] thriftParams) throws Exception { protected abstract void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType, Type... parameterTypes) throws UdfRuntimeException; - protected Object[] allocateInputObjects(long row, int argClassOffset) throws UdfRuntimeException { - Object[] inputObjects = new Object[argTypes.length]; - - for (int i = 0; i < argTypes.length; ++i) { - if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1 - && (UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) == 1)) { - inputObjects[i] = null; - continue; - } - switch (argTypes[i]) { - case BOOLEAN: - inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case TINYINT: - inputObjects[i] = UdfUtils.UNSAFE.getByte(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case SMALLINT: - inputObjects[i] = UdfUtils.UNSAFE.getShort(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case INT: - inputObjects[i] = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case BIGINT: - inputObjects[i] = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case FLOAT: - inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DOUBLE: - inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DATE: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i + argClassOffset]); - break; - } - case DATETIME: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + argClassOffset]); - break; - } - case DATEV2: { - int data = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i + argClassOffset]); - break; - } - case DATETIMEV2: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + argClassOffset]); - break; - } - case LARGEINT: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); - break; - } - case DECIMALV2: - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); - inputObjects[i] = new BigDecimal(value, argTypes[i].getScale()); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); - long numBytes = row == 0 ? offset - : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))); - long base = row == 0 - ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - : UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + offset - numBytes; - byte[] bytes = new byte[(int) numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); - break; - } - case ARRAY_TYPE: { - Type type = argTypes[i].getItemType(); - inputObjects[i] = arrayTypeInputData(type, i, row); - break; - } - default: - throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); - } - } - return inputObjects; - } - - public ArrayList arrayTypeInputData(Type type, int argIdx, long row) - throws UdfRuntimeException { - long offsetStart = (row == 0) ? 0 - : Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, argIdx)) + 8L * (row - 1))); - long offsetEnd = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, argIdx)) + 8L * row)); - long arrayNullMapBase = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)); - long arrayInputBufferBase = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)); - - switch (type.getPrimitiveType()) { - case BOOLEAN: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - boolean value = UdfUtils.UNSAFE.getBoolean(null, arrayInputBufferBase + offsetRow); - data.add(value); - } - } - return data; - } - case TINYINT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - byte value = UdfUtils.UNSAFE.getByte(null, arrayInputBufferBase + offsetRow); - data.add(value); - } - } - return data; - } - case SMALLINT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - short value = UdfUtils.UNSAFE.getShort(null, arrayInputBufferBase + 2L * offsetRow); - data.add(value); - } - } - return data; - } - case INT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - int value = UdfUtils.UNSAFE.getInt(null, arrayInputBufferBase + 4L * offsetRow); - data.add(value); - } - } - return data; - } - case BIGINT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - data.add(value); - } - } - return data; - } - case FLOAT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - float value = UdfUtils.UNSAFE.getFloat(null, arrayInputBufferBase + 4L * offsetRow); - data.add(value); - } - } - return data; - } - case DOUBLE: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - double value = UdfUtils.UNSAFE.getDouble(null, arrayInputBufferBase + 8L * offsetRow); - data.add(value); - } - } - return data; - } - case DATE: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - // TODO: now argClass[argIdx + argClassOffset] is java.util.ArrayList, can't get - // nested class type - // LocalDate obj = UdfUtils.convertDateToJavaDate(value, argClass[argIdx + - // argClassOffset]); - LocalDate obj = (LocalDate) UdfUtils.convertDateToJavaDate(value, LocalDate.class); - data.add(obj); - } - } - return data; - } - case DATETIME: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - // Object obj = UdfUtils.convertDateTimeToJavaDateTime(value, argClass[argIdx + - // argClassOffset]); - LocalDateTime obj = (LocalDateTime) UdfUtils.convertDateTimeToJavaDateTime(value, - LocalDateTime.class); - data.add(obj); - } - } - return data; - } - case DATEV2: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - int value = UdfUtils.UNSAFE.getInt(null, arrayInputBufferBase + 4L * offsetRow); - // Object obj = UdfUtils.convertDateV2ToJavaDate(value, argClass[argIdx + - // argClassOffset]); - LocalDate obj = (LocalDate) UdfUtils.convertDateV2ToJavaDate(value, LocalDate.class); - data.add(obj); - } - } - return data; - } - case DATETIMEV2: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - LocalDateTime obj = (LocalDateTime) UdfUtils.convertDateTimeV2ToJavaDateTime(value, - LocalDateTime.class); - data.add(obj); - } - } - return data; - } - case LARGEINT: { - ArrayList data = new ArrayList<>(); - byte[] bytes = new byte[16]; - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 16L * offsetRow); - UdfUtils.copyMemory(null, value, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); - data.add(new BigInteger(UdfUtils.convertByteOrder(bytes))); - } - } - return data; - } - case DECIMALV2: - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - int len; - if (type.getPrimitiveType() == PrimitiveType.DECIMAL32) { - len = 4; - } else if (type.getPrimitiveType() == PrimitiveType.DECIMAL64) { - len = 8; - } else { - len = 16; - } - ArrayList data = new ArrayList<>(); - byte[] bytes = new byte[len]; - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + len * offsetRow); - UdfUtils.copyMemory(null, value, bytes, UdfUtils.BYTE_ARRAY_OFFSET, len); - BigInteger bigInteger = new BigInteger(UdfUtils.convertByteOrder(bytes)); - data.add(new BigDecimal(bigInteger, argTypes[argIdx].getScale())); - } - } - return data; - } - case CHAR: - case VARCHAR: - case STRING: { - ArrayList data = new ArrayList<>(); - long strOffsetBase = UdfUtils.UNSAFE - .getLong(null, UdfUtils.getAddressAtOffset(inputArrayStringOffsetsPtrs, argIdx)); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long stringOffsetStart = (offsetRow == 0) ? 0 - : Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, strOffsetBase + 4L * (offsetRow - 1))); - long stringOffsetEnd = Integer - .toUnsignedLong(UdfUtils.UNSAFE.getInt(null, strOffsetBase + 4L * offsetRow)); - - long numBytes = stringOffsetEnd - stringOffsetStart; - long base = arrayInputBufferBase + stringOffsetStart; - byte[] bytes = new byte[(int) numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - data.add(new String(bytes, StandardCharsets.UTF_8)); - } - } - return data; - } - default: - throw new UdfRuntimeException("Unsupported argument type in nested array: " + type); - } - } - - protected abstract long getCurrentOutputOffset(long row, boolean isArrayType); - /** * Close the class loader we may have created. */ @@ -502,76 +120,74 @@ public void close() { classLoader = null; } - // Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_ - protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0); - } + public void copyTupleBasicResult(Object obj, long row, Class retClass, + long outputBufferBase, long charsAddress, long offsetsAddr, JavaUdfDataType retType) + throws UdfRuntimeException { switch (retType) { case BOOLEAN: { boolean val = (boolean) obj; - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putByte(outputBufferBase + row * retType.getLen(), val ? (byte) 1 : 0); - return true; + break; } case TINYINT: { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putByte(outputBufferBase + row * retType.getLen(), (byte) obj); - return true; + break; } case SMALLINT: { - UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putShort(outputBufferBase + row * retType.getLen(), (short) obj); - return true; + break; } case INT: { - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putInt(outputBufferBase + row * retType.getLen(), (int) obj); - return true; + break; } case BIGINT: { - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), (long) obj); - return true; + break; } case FLOAT: { - UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putFloat(outputBufferBase + row * retType.getLen(), (float) obj); - return true; + break; } case DOUBLE: { - UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putDouble(outputBufferBase + row * retType.getLen(), (double) obj); - return true; + break; } case DATE: { long time = UdfUtils.convertToDate(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), time); + break; } case DATETIME: { long time = UdfUtils.convertToDateTime(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), time); + break; } case DATEV2: { int time = UdfUtils.convertToDateV2(obj, retClass); - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putInt(outputBufferBase + row * retType.getLen(), time); + break; } case DATETIMEV2: { long time = UdfUtils.convertToDateTimeV2(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), time); + break; } case LARGEINT: { BigInteger data = (BigInteger) obj; byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //here value is 16 bytes, so if result data greater than the maximum of 16 bytes - //it will return a wrong num to backend; + // here value is 16 bytes, so if result data greater than the maximum of 16 + // bytesit will return a wrong num to backend; byte[] value = new byte[16]; - //check data is negative + // check data is negative if (data.signum() == -1) { Arrays.fill(value, (byte) -1); } @@ -580,14 +196,14 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; + outputBufferBase + row * retType.getLen(), value.length); + break; } case DECIMALV2: { BigDecimal retValue = ((BigDecimal) obj).setScale(9, RoundingMode.HALF_EVEN); BigInteger data = retValue.unscaledValue(); byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle + // TODO: here is maybe overflow also, and may find a better way to handle byte[] value = new byte[16]; if (data.signum() == -1) { Arrays.fill(value, (byte) -1); @@ -598,8 +214,8 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; + outputBufferBase + row * retType.getLen(), value.length); + break; } case DECIMAL32: case DECIMAL64: @@ -607,7 +223,7 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud BigDecimal retValue = ((BigDecimal) obj).setScale(retType.getScale(), RoundingMode.HALF_EVEN); BigInteger data = retValue.unscaledValue(); byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle + // TODO: here is maybe overflow also, and may find a better way to handle byte[] value = new byte[retType.getLen()]; if (data.signum() == -1) { Arrays.fill(value, (byte) -1); @@ -618,413 +234,29 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; + outputBufferBase + row * retType.getLen(), value.length); + break; } case CHAR: case VARCHAR: case STRING: { - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); - long offset = getCurrentOutputOffset(row, false); - if (offset + bytes.length > bufferSize) { - return false; - } + long offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1)); + int needLen = (int) (offset + bytes.length); + outputBufferBase = JNINativeMethod.resizeStringColumn(charsAddress, needLen); offset += bytes.length; - UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row, - Integer.parseUnsignedInt(String.valueOf(offset))); - UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - bytes.length, bytes.length); + UdfUtils.UNSAFE.putInt(null, offsetsAddr + 4L * row, Integer.parseUnsignedInt(String.valueOf(offset))); + UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, outputBufferBase + offset - bytes.length, + bytes.length); updateOutputOffset(offset); - return true; - } - case ARRAY_TYPE: { - Type type = retType.getItemType(); - return arrayTypeOutputData(obj, type, row); + break; } + case ARRAY_TYPE: default: throw new UdfRuntimeException("Unsupported return type: " + retType); } } - public boolean arrayTypeOutputData(Object obj, Type type, long row) throws UdfRuntimeException { - long offset = getCurrentOutputOffset(row, true); - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); - long outputNullMapBase = UdfUtils.UNSAFE.getLong(null, outputArrayNullPtr); - long outputBufferBase = UdfUtils.UNSAFE.getLong(null, outputBufferPtr); - switch (type.getPrimitiveType()) { - case BOOLEAN: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Boolean value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putByte(outputBufferBase + (offset + i), value ? (byte) 1 : 0); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case TINYINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Byte value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putByte(outputBufferBase + (offset + i), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case SMALLINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Short value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putShort(outputBufferBase + ((offset + i) * 2L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case INT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Integer value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putInt(outputBufferBase + ((offset + i) * 4L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case BIGINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Long value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case FLOAT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Float value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putFloat(outputBufferBase + ((offset + i) * 4L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DOUBLE: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Double value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putDouble(outputBufferBase + ((offset + i) * 8L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATE: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDate value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - long time = UdfUtils.convertToDate(value, LocalDate.class); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATETIME: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDateTime value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - long time = UdfUtils.convertToDateTime(value, LocalDateTime.class); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATEV2: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDate value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - int time = UdfUtils.convertToDateV2(value, LocalDate.class); - UdfUtils.UNSAFE.putInt(outputBufferBase + ((offset + i) * 4L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATETIMEV2: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDateTime value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - long time = UdfUtils.convertToDateTimeV2(value, LocalDateTime.class); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case LARGEINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - BigInteger bigInteger = data.get(i); - if (bigInteger == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - byte[] bytes = UdfUtils.convertByteOrder(bigInteger.toByteArray()); - byte[] value = new byte[16]; - // check data is negative - if (bigInteger.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + ((offset + i) * 16L), value.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DECIMALV2: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - BigDecimal bigDecimal = data.get(i); - if (bigDecimal == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - BigInteger bigInteger = bigDecimal.setScale(9, RoundingMode.HALF_EVEN).unscaledValue(); - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - byte[] bytes = UdfUtils.convertByteOrder(bigInteger.toByteArray()); - byte[] value = new byte[16]; - // check data is negative - if (bigInteger.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + ((offset + i) * 16L), value.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - BigDecimal bigDecimal = data.get(i); - if (bigDecimal == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - BigInteger bigInteger = bigDecimal.setScale(retType.getScale(), RoundingMode.HALF_EVEN) - .unscaledValue(); - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - byte[] bytes = UdfUtils.convertByteOrder(bigInteger.toByteArray()); - byte[] value = new byte[16]; - // check data is negative - if (bigInteger.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + ((offset + i) * 16L), value.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case CHAR: - case VARCHAR: - case STRING: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - long outputStrOffsetBase = UdfUtils.UNSAFE.getLong(null, outputArrayStringOffsetsPtr); - for (int i = 0; i < num; ++i) { - String value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - byte[] bytes = value.getBytes(StandardCharsets.UTF_8); - long strOffset = (offset + i == 0) ? 0 - : Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - outputStrOffsetBase + ((offset + i - 1) * 4L))); - if (strOffset + bytes.length > bufferSize) { - return false; - } - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - strOffset += bytes.length; - UdfUtils.UNSAFE.putInt(null, outputStrOffsetBase + 4L * (offset + i), - Integer.parseUnsignedInt(String.valueOf(strOffset))); - UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + strOffset - bytes.length, bytes.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - default: - throw new UdfRuntimeException("Unsupported argument type in nested array: " + type); - } - } protected void updateOutputOffset(long offset) { } @@ -1556,120 +788,129 @@ public void copyBatchArrayResultImpl(boolean isNullable, int numRows, Object[] r PrimitiveType type) { long hasPutElementNum = 0; for (int row = 0; row < numRows; ++row) { - switch (type) { - case BOOLEAN: { - hasPutElementNum = UdfConvert - .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case TINYINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case SMALLINT: { - hasPutElementNum = UdfConvert - .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case INT: { - hasPutElementNum = UdfConvert - .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case BIGINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case LARGEINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case FLOAT: { - hasPutElementNum = UdfConvert - .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DOUBLE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - hasPutElementNum = UdfConvert - .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); - break; - } - case DATE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIME: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIMEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMALV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL32: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL64: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL128: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - default: { - Preconditions.checkState(false, "Not support type in array: " + retType); - break; - } + hasPutElementNum = copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, result[row], nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr, type); + } + } + + public long copyTupleArrayResultImpl(long hasPutElementNum, boolean isNullable, int row, Object result, + long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr, + PrimitiveType type) { + switch (type) { + case BOOLEAN: { + hasPutElementNum = UdfConvert + .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case TINYINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case SMALLINT: { + hasPutElementNum = UdfConvert + .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case INT: { + hasPutElementNum = UdfConvert + .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case BIGINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case LARGEINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case FLOAT: { + hasPutElementNum = UdfConvert + .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DOUBLE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + hasPutElementNum = UdfConvert + .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); + break; + } + case DATE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIME: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIMEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMALV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL32: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL64: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL128: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + default: { + Preconditions.checkState(false, "Not support type in array: " + retType); + break; } } + return hasPutElementNum; } public void buildArrayListFromHashMap(Object[] result, PrimitiveType keyType, PrimitiveType valueType, diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index dff689ed4038d6..fa19ad32888d2f 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -163,43 +163,6 @@ public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, } } - /** - * invoke add function, add row in loop [rowStart, rowEnd). - */ - public void add(boolean isSinglePlace, long rowStart, long rowEnd) throws UdfRuntimeException { - try { - long idx = rowStart; - do { - Long curPlace = null; - if (isSinglePlace) { - curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr)); - } else { - curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx); - } - Object[] inputArgs = new Object[argTypes.length + 1]; - Object state = stateObjMap.get(curPlace); - if (state != null) { - inputArgs[0] = state; - } else { - Object newState = createAggState(); - stateObjMap.put(curPlace, newState); - inputArgs[0] = newState; - } - do { - Object[] inputObjects = allocateInputObjects(idx, 1); - for (int i = 0; i < argTypes.length; ++i) { - inputArgs[i + 1] = inputObjects[i]; - } - allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs); - idx++; - } while (isSinglePlace && idx < rowEnd); - } while (idx < rowEnd); - } catch (Exception e) { - LOG.warn("invoke add function meet some error: " + e.getCause().toString()); - throw new UdfRuntimeException("UDAF failed to add: ", e); - } - } - /** * invoke user create function to get obj. */ @@ -292,40 +255,71 @@ public void merge(long place, byte[] data) throws UdfRuntimeException { /** * invoke getValue to return finally result. */ - public boolean getValue(long row, long place) throws UdfRuntimeException { + + public Object getValue(long place) throws UdfRuntimeException { try { if (stateObjMap.get(place) == null) { stateObjMap.put(place, createAggState()); } - return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)), - row, retClass); + return allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)); } catch (Exception e) { LOG.warn("invoke getValue function meet some error: " + e.getCause().toString()); throw new UdfRuntimeException("UDAF failed to result", e); } } - @Override - protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { - if (obj == null) { - // If result is null, return true directly when row == 0 as we have already inserted default value. - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) { + public void copyTupleBasicResult(Object result, int row, long outputNullMapPtr, long outputBufferBase, + long charsAddress, + long offsetsAddr) throws UdfRuntimeException { + if (result == null) { + // put null obj + if (outputNullMapPtr == -1) { throw new UdfRuntimeException("UDAF failed to store null data to not null column"); + } else { + UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 1); } - return true; + return; + } + try { + if (outputNullMapPtr != -1) { + UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 0); + } + copyTupleBasicResult(result, row, retClass, outputBufferBase, charsAddress, + offsetsAddr, retType); + } catch (UdfRuntimeException e) { + LOG.info(e.toString()); } - return super.storeUdfResult(obj, row, retClass); } - @Override - protected long getCurrentOutputOffset(long row, boolean isArrayType) { - if (isArrayType) { - return Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1))); - } else { - return Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1))); + public void copyTupleArrayResult(long hasPutElementNum, boolean isNullable, int row, Object result, + long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) throws UdfRuntimeException { + if (nullMapAddr > 0) { + UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0); + } + copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, result, nullMapAddr, offsetsAddr, nestedNullMapAddr, + dataAddr, strOffsetAddr, retType.getItemType().getPrimitiveType()); + } + + public void copyTupleMapResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, + long offsetsAddr, + long keyNsestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr, + long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) throws UdfRuntimeException { + if (nullMapAddr > 0) { + UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0); } + PrimitiveType keyType = retType.getKeyType().getPrimitiveType(); + PrimitiveType valueType = retType.getValueType().getPrimitiveType(); + Object[] keyCol = new Object[1]; + Object[] valueCol = new Object[1]; + Object[] resultArr = new Object[1]; + resultArr[0] = result; + buildArrayListFromHashMap(resultArr, keyType, valueType, keyCol, valueCol); + copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, + valueCol[0], nullMapAddr, offsetsAddr, + valueNsestedNullMapAddr, valueDataAddr, valueStrOffsetAddr, valueType); + copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, keyCol[0], nullMapAddr, offsetsAddr, + keyNsestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyType); } @Override diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java index dc835408859fbe..7b3a151f0065af 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java @@ -707,9 +707,9 @@ public static void copyBatchStringResult(boolean isNullable, int numRows, String //////////////////////////////////// copyBatchArray////////////////////////////////////////////////////////// - public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -741,9 +741,9 @@ public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean is return hasPutElementNum; } - public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -775,9 +775,9 @@ public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean is return hasPutElementNum; } - public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -809,9 +809,9 @@ public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean i return hasPutElementNum; } - public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -843,9 +843,9 @@ public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNull return hasPutElementNum; } - public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -877,9 +877,9 @@ public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isN return hasPutElementNum; } - public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -911,9 +911,9 @@ public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNu return hasPutElementNum; } - public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -945,9 +945,9 @@ public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isN return hasPutElementNum; } - public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -981,9 +981,9 @@ public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNul return hasPutElementNum; } - public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1017,9 +1017,9 @@ public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean i return hasPutElementNum; } - public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1054,9 +1054,9 @@ public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isN } public static long copyBatchArrayDateTimeV2Result(long hasPutElementNum, boolean isNullable, int row, - Object[] result, + Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1090,9 +1090,9 @@ public static long copyBatchArrayDateTimeV2Result(long hasPutElementNum, boolean return hasPutElementNum; } - public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1140,9 +1140,9 @@ public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean i return hasPutElementNum; } - public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1194,9 +1194,9 @@ public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean is public static long copyBatchArrayDecimalV3Result(int scale, long typeLen, long hasPutElementNum, boolean isNullable, int row, - Object[] result, + Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1247,9 +1247,9 @@ public static long copyBatchArrayDecimalV3Result(int scale, long typeLen, long h } public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isNullable, int row, - Object[] result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr, + Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1270,8 +1270,12 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN offset += byteRes[i].length; offsets[i] = offset; } - byte[] bytes = new byte[offsets[num - 1] - oldOffsetNum]; - long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, offsets[num - 1]); + int oldSzie = 0; + if (num > 0) { + oldSzie = offsets[num - 1]; + } + byte[] bytes = new byte[oldSzie - oldOffsetNum]; + long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldSzie); int dst = 0; for (int i = 0; i < num; i++) { for (int j = 0; j < byteRes[i].length; j++) { @@ -1281,7 +1285,7 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN UdfUtils.copyMemory(offsets, UdfUtils.INT_ARRAY_OFFSET, null, strOffsetAddr + (4L * hasPutElementNum), num * 4L); UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, bytesAddr + oldOffsetNum, - offsets[num - 1] - oldOffsetNum); + oldSzie - oldOffsetNum); hasPutElementNum = hasPutElementNum + num; } } else { @@ -1300,9 +1304,13 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN offset += byteRes[i].length; offsets[i] = offset; } - byte[] bytes = new byte[offsets[num - 1]]; int oldOffsetNum = UdfUtils.UNSAFE.getInt(null, strOffsetAddr + ((hasPutElementNum - 1) * 4L)); - long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldOffsetNum + offsets[num - 1]); + int oldSzie = 0; + if (num > 0) { + oldSzie = offsets[num - 1]; + } + byte[] bytes = new byte[oldSzie]; + long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldOffsetNum + oldSzie); int dst = 0; for (int i = 0; i < num; i++) { for (int j = 0; j < byteRes[i].length; j++) { @@ -1312,7 +1320,7 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN UdfUtils.copyMemory(offsets, UdfUtils.INT_ARRAY_OFFSET, null, strOffsetAddr + (4L * oldOffsetNum), num * 4L); UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, bytesAddr + oldOffsetNum, - offsets[num - 1]); + oldSzie); hasPutElementNum = hasPutElementNum + num; } UdfUtils.UNSAFE.putLong(null, offsetsAddr + 8L * row, hasPutElementNum); diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 2f6ca99fdd5120..a77b441b67d997 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -74,50 +74,6 @@ public void close() { super.close(); } - /** - * evaluate function called by the backend. The inputs to the UDF have - * been serialized to 'input' - */ - public void evaluate() throws UdfRuntimeException { - int batchSize = UdfUtils.UNSAFE.getInt(null, batchSizePtr); - try { - if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.VARCHAR) - || retType.equals(JavaUdfDataType.CHAR) || retType.equals(JavaUdfDataType.ARRAY_TYPE) - || retType.equals(JavaUdfDataType.MAP_TYPE)) { - // If this udf return variable-size type (e.g.) String, we have to allocate output - // buffer multiple times until buffer size is enough to store output column. So we - // always begin with the last evaluated row instead of beginning of this batch. - rowIdx = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr + 8); - if (rowIdx == 0) { - outputOffset = 0L; - } - } else { - rowIdx = 0; - } - for (; rowIdx < batchSize; rowIdx++) { - inputObjects = allocateInputObjects(rowIdx, 0); - // `storeUdfResult` is called to store udf result to output column. If true - // is returned, current value is stored successfully. Otherwise, current result is - // not processed successfully (e.g. current output buffer is not large enough) so - // we break this loop directly. - if (!storeUdfResult(evaluate(inputObjects), rowIdx, method.getReturnType())) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx); - return; - } - } - } catch (Exception e) { - if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.ARRAY_TYPE) - || retType.equals(JavaUdfDataType.MAP_TYPE)) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, batchSize); - } - throw new UdfRuntimeException("UDF::evaluate() ran into a problem.", e); - } - if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.ARRAY_TYPE) - || retType.equals(JavaUdfDataType.MAP_TYPE)) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx); - } - } - public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr, long columnAddr, long strOffsetAddr) { return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr); @@ -211,30 +167,6 @@ public Method getMethod() { return method; } - // Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_ - @Override - protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { - if (obj == null) { - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) { - throw new UdfRuntimeException("UDF failed to store null data to not null column"); - } - UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 1); - if (retType.equals(JavaUdfDataType.STRING)) { - UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) - + 4L * row, Integer.parseUnsignedInt(String.valueOf(outputOffset))); - } else if (retType.equals(JavaUdfDataType.ARRAY_TYPE)) { - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(outputOffset))); - } - return true; - } - return super.storeUdfResult(obj, row, retClass); - } - - @Override - protected long getCurrentOutputOffset(long row, boolean isArrayType) { - return outputOffset; - } @Override protected void updateOutputOffset(long offset) { diff --git a/fe/be-java-extensions/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java b/fe/be-java-extensions/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java deleted file mode 100644 index 7c725da50e00ee..00000000000000 --- a/fe/be-java-extensions/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java +++ /dev/null @@ -1,600 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.udf; - -import org.apache.doris.common.jni.utils.UdfUtils; -import org.apache.doris.thrift.TFunction; -import org.apache.doris.thrift.TFunctionBinaryType; -import org.apache.doris.thrift.TFunctionName; -import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; -import org.apache.doris.thrift.TPrimitiveType; -import org.apache.doris.thrift.TScalarFunction; -import org.apache.doris.thrift.TScalarType; -import org.apache.doris.thrift.TTypeDesc; -import org.apache.doris.thrift.TTypeNode; -import org.apache.doris.thrift.TTypeNodeType; - -import org.apache.thrift.TSerializer; -import org.apache.thrift.protocol.TBinaryProtocol; -import org.junit.Test; - -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; - -public class UdfExecutorTest { - - @Test - public void testDateTimeUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.DateTimeUdf"; - - TFunction fn = new TFunction(); - fn.setBinaryType(TFunctionBinaryType.JAVA_UDF); - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.setScalarType(new TScalarType(TPrimitiveType.INT)); - fn.setRetType(new TTypeDesc(Collections.singletonList(typeNode))); - - TTypeNode typeNodeArg = new TTypeNode(TTypeNodeType.SCALAR); - typeNodeArg.setScalarType(new TScalarType(TPrimitiveType.DATETIME)); - TTypeDesc typeDescArg = new TTypeDesc(Collections.singletonList(typeNodeArg)); - fn.arg_types = Arrays.asList(typeDescArg); - - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("DateTimeUdf"); - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 1; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(8 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - - long[] inputLongDateTime = - new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, - 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, - 565212791469375654L, 565494266446086310L}; - - for (int i = 0; i < batchSize; ++i) { - UdfUtils.UNSAFE.putLong(null, inputBuffer1 + i * 8, inputLongDateTime[i]); - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - } - - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor = new UdfExecutor(serializer.serialize(params)); - executor.evaluate(); - - for (int i = 0; i < batchSize; ++i) { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == (2000 + i)); - } - } - - @Test - public void testDecimalUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.DecimalUdf"; - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - TScalarType scalarType = new TScalarType(TPrimitiveType.DECIMALV2); - scalarType.setScale(9); - scalarType.setPrecision(27); - typeNode.scalar_type = scalarType; - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("DecimalUdf"); - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); - - long[] inputLong = - new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, - 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, - 565212791469375654L, 565494266446086310L}; - - BigDecimal[] decimalArray = new BigDecimal[10]; - for (int i = 0; i < batchSize; ++i) { - BigInteger temp = BigInteger.valueOf(inputLong[i]); - decimalArray[i] = new BigDecimal(temp, 9); - } - - BigDecimal decimal2 = new BigDecimal(BigInteger.valueOf(0L), 9); - byte[] intput2 = convertByteOrder(decimal2.unscaledValue().toByteArray()); - byte[] value2 = new byte[16]; - if (decimal2.signum() == -1) { - Arrays.fill(value2, (byte) -1); - } - for (int index = 0; index < Math.min(intput2.length, value2.length); ++index) { - value2[index] = intput2[index]; - } - - for (int i = 0; i < batchSize; ++i) { - byte[] intput1 = convertByteOrder(decimalArray[i].unscaledValue().toByteArray()); - byte[] value1 = new byte[16]; - if (decimalArray[i].signum() == -1) { - Arrays.fill(value1, (byte) -1); - } - for (int index = 0; index < Math.min(intput1.length, value1.length); ++index) { - value1[index] = intput1[index]; - } - UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length); - UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length); - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); - } - - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params)); - udfExecutor.evaluate(); - - for (int i = 0; i < batchSize; ++i) { - byte[] bytes = new byte[16]; - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length); - - BigInteger integer = new BigInteger(convertByteOrder(bytes)); - BigDecimal result = new BigDecimal(integer, 9); - assert (result.equals(decimalArray[i])); - } - } - - @Test - public void testConstantOneUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.ConstantOneUdf"; - - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.INT); - fn.ret_type = new TTypeDesc(Collections.singletonList(typeNode)); - fn.arg_types = new ArrayList<>(); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("ConstantOne"); - - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - params.setInputBufferPtrs(0); - params.setInputNullsPtrs(0); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = - new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor; - executor = new UdfExecutor(serializer.serialize(params)); - - executor.evaluate(); - for (int i = 0; i < 10; i++) { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == 1); - } - } - - @Test - public void testSimpleAddUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.SimpleAddUdf"; - - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.INT); - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("SimpleAdd"); - - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); - - for (int i = 0; i < batchSize; i++) { - UdfUtils.UNSAFE.putInt(null, inputBuffer1 + i * 4, i); - UdfUtils.UNSAFE.putInt(null, inputBuffer2 + i * 4, i); - - if (i % 2 == 0) { - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - } - UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); - } - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = - new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor; - executor = new UdfExecutor(serializer.serialize(params)); - - executor.evaluate(); - for (int i = 0; i < batchSize; i++) { - if (i % 2 == 0) { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 1); - } else { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == i * 2); - } - } - } - - @Test - public void testStringConcatUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.StringConcatUdf"; - - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.STRING); - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("StringConcat"); - - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(32); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputIntermediateStatePtr = UdfUtils.UNSAFE.allocateMemory(8 * 2); - - String[] input1 = new String[batchSize]; - String[] input2 = new String[batchSize]; - long[] inputOffsets1 = new long[batchSize]; - long[] inputOffsets2 = new long[batchSize]; - long inputBufferSize1 = 0; - long inputBufferSize2 = 0; - for (int i = 0; i < batchSize; i++) { - input1[i] = "Input1_" + i; - input2[i] = "Input2_" + i; - inputOffsets1[i] = i == 0 ? input1[i].getBytes(StandardCharsets.UTF_8).length - : inputOffsets1[i - 1] + input1[i].getBytes(StandardCharsets.UTF_8).length; - inputOffsets2[i] = i == 0 ? input2[i].getBytes(StandardCharsets.UTF_8).length - : inputOffsets2[i - 1] + input2[i].getBytes(StandardCharsets.UTF_8).length; - inputBufferSize1 += input1[i].getBytes(StandardCharsets.UTF_8).length; - inputBufferSize2 += input2[i].getBytes(StandardCharsets.UTF_8).length; - } - // In our test case, output buffer is (8 + 1) bytes * batchSize - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + inputBufferSize2 + batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - long outputOffset = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - UdfUtils.UNSAFE.putLong(outputOffsetsPtr, outputOffset); - // reserved buffer size - UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr, inputBufferSize1 + inputBufferSize2 + batchSize); - // current row id - UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr + 8, 0); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - params.setOutputOffsetsPtr(outputOffsetsPtr); - params.setOutputIntermediateStatePtr(outputIntermediateStatePtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + batchSize); - long inputOffset1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize2 + batchSize); - long inputOffset2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, -1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, -1); - UdfUtils.UNSAFE.putLong(inputOffsetsPtr, inputOffset1); - UdfUtils.UNSAFE.putLong(inputOffsetsPtr + 8, inputOffset2); - - for (int i = 0; i < batchSize; i++) { - if (i == 0) { - UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1, - input1[i].getBytes(StandardCharsets.UTF_8).length); - UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2, - input2[i].getBytes(StandardCharsets.UTF_8).length); - } else { - UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + inputOffsets1[i - 1], - input1[i].getBytes(StandardCharsets.UTF_8).length); - UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + inputOffsets2[i - 1], - input2[i].getBytes(StandardCharsets.UTF_8).length); - } - UdfUtils.UNSAFE.putInt(null, inputOffset1 + 4L * i, - Integer.parseUnsignedInt(String.valueOf(inputOffsets1[i]))); - UdfUtils.UNSAFE.putInt(null, inputOffset2 + 4L * i, - Integer.parseUnsignedInt(String.valueOf(inputOffsets2[i]))); - } - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(inputOffsetsPtr); - - TBinaryProtocol.Factory factory = - new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor; - executor = new UdfExecutor(serializer.serialize(params)); - - executor.evaluate(); - for (int i = 0; i < batchSize; i++) { - byte[] bytes = new byte[input1[i].getBytes(StandardCharsets.UTF_8).length - + input2[i].getBytes(StandardCharsets.UTF_8).length]; - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - if (i == 0) { - UdfUtils.copyMemory(null, outputBuffer, bytes, UdfUtils.BYTE_ARRAY_OFFSET, - bytes.length); - } else { - long lastOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 * (i - 1)); - UdfUtils.copyMemory(null, outputBuffer + lastOffset, bytes, UdfUtils.BYTE_ARRAY_OFFSET, - bytes.length); - } - assert (new String(bytes, StandardCharsets.UTF_8).equals(input1[i] + input2[i])); - assert (UdfUtils.UNSAFE.getByte(null, outputNull + i) == 0); - } - } - - @Test - public void testLargeIntUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.LargeIntUdf"; - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.LARGEINT); - - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("LargeIntUdf"); - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); - - long[] inputLong = - new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, - 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, - 565212791469375654L, 565494266446086310L}; - - BigInteger[] integerArray = new BigInteger[10]; - for (int i = 0; i < batchSize; ++i) { - integerArray[i] = BigInteger.valueOf(inputLong[i]); - } - BigInteger integer2 = BigInteger.valueOf(1L); - byte[] intput2 = convertByteOrder(integer2.toByteArray()); - byte[] value2 = new byte[16]; - if (integer2.signum() == -1) { - Arrays.fill(value2, (byte) -1); - } - for (int index = 0; index < Math.min(intput2.length, value2.length); ++index) { - value2[index] = intput2[index]; - } - - for (int i = 0; i < batchSize; ++i) { - byte[] intput1 = convertByteOrder(integerArray[i].toByteArray()); - byte[] value1 = new byte[16]; - if (integerArray[i].signum() == -1) { - Arrays.fill(value1, (byte) -1); - } - for (int index = 0; index < Math.min(intput1.length, value1.length); ++index) { - value1[index] = intput1[index]; - } - UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length); - UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length); - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); - } - - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params)); - udfExecutor.evaluate(); - - for (int i = 0; i < batchSize; ++i) { - byte[] bytes = new byte[16]; - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length); - BigInteger result = new BigInteger(convertByteOrder(bytes)); - assert (result.equals(integerArray[i].add(BigInteger.valueOf(1)))); - } - } - - public byte[] convertByteOrder(byte[] bytes) { - int length = bytes.length; - for (int i = 0; i < length / 2; ++i) { - byte temp = bytes[i]; - bytes[i] = bytes[length - 1 - i]; - bytes[length - 1 - i] = temp; - } - return bytes; - } -} diff --git a/regression-test/data/javaudf_p0/test_javaudaf_return_map.out b/regression-test/data/javaudf_p0/test_javaudaf_return_map.out new file mode 100644 index 00000000000000..1a4ff1b2bfd09d --- /dev/null +++ b/regression-test/data/javaudf_p0/test_javaudaf_return_map.out @@ -0,0 +1,31 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select_1 -- +{1:10, 2:20, 3:30, 4:40, 5:50} + +-- !select_2 -- +{1:0.01, 2:0.02, 3:0.03, 4:0.04, 5:0.05} + +-- !select_3 -- +{1:10} +{2:20} +{3:30} +{4:40} +{5:50} + +-- !select_4 -- +{1:0.01} +{2:0.02} +{3:0.03} +{4:0.04} +{5:0.05} + +-- !select_5 -- +{"2 114":"0.02 514", "3 114":"0.03 514", "1 114":"0.01 514", "5 114":"0.05 514", "4 114":"0.04 514"} + +-- !select_6 -- +{"1 114":"0.01 514"} +{"2 114":"0.02 514"} +{"3 114":"0.03 514"} +{"4 114":"0.04 514"} +{"5 114":"0.05 514"} + diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MyReturnMapString.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MyReturnMapString.java new file mode 100644 index 00000000000000..a416a8371e4abc --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MyReturnMapString.java @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.*; + + +public class MyReturnMapString { + private static final Logger LOG = Logger.getLogger(MyReturnMapString.class); + public static class State { + public HashMap counter = new HashMap<>(); + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, Integer k, Double v) { + LOG.info("udaf nest k v " + k + " " + v); + state.counter.put(k, v); + } + + public void serialize(State state, DataOutputStream out) throws IOException { + int size = state.counter.size(); + out.writeInt(size); + for(Map.Entry it : state.counter.entrySet()){ + out.writeInt(it.getKey()); + out.writeDouble(it.getValue()); + } + } + + public void deserialize(State state, DataInputStream in) throws IOException { + int size = in.readInt(); + for (int i = 0; i < size; ++i) { + Integer key = in.readInt(); + Double value = in.readDouble(); + state.counter.put(key, value); + } + } + + public void merge(State state, State rhs) { + for(Map.Entry it : rhs.counter.entrySet()){ + state.counter.put(it.getKey(), it.getValue()); + } + } + + public HashMap getValue(State state) { + //sort for regression test + HashMap map = new HashMap<>(); + for(Map.Entry it : state.counter.entrySet()){ + map.put(it.getKey() + " 114", it.getValue() + " 514"); + } + return map; + } +} \ No newline at end of file diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapInt.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapInt.java new file mode 100644 index 00000000000000..cab664ef36168e --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapInt.java @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.*; + + +public class MySumReturnMapInt { + private static final Logger LOG = Logger.getLogger(MySumReturnMapInt.class); + public static class State { + public HashMap counter = new HashMap<>(); + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, Integer val) { + if (val == null) { + return; + } + state.counter.put(val, 10 * val); + } + + public void serialize(State state, DataOutputStream out) throws IOException { + int size = state.counter.size(); + out.writeInt(size); + for(Map.Entry it : state.counter.entrySet()){ + out.writeInt(it.getKey()); + out.writeInt(it.getValue()); + } + } + + public void deserialize(State state, DataInputStream in) throws IOException { + int size = in.readInt(); + for (int i = 0; i < size; ++i) { + Integer key = in.readInt(); + Integer value = in.readInt(); + state.counter.put(key, value); + } + } + + public void merge(State state, State rhs) { + for(Map.Entry it : rhs.counter.entrySet()){ + state.counter.put(it.getKey(), it.getValue()); + } + } + + public HashMap getValue(State state) { + //sort for regression test + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapIntDou.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapIntDou.java new file mode 100644 index 00000000000000..7a86666ef3b535 --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapIntDou.java @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; + +import com.carrotsearch.hppc.DoubleByteAssociativeContainer; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.*; + + +public class MySumReturnMapIntDou { + private static final Logger LOG = Logger.getLogger(MySumReturnMapIntDou.class); + public static class State { + public HashMap counter = new HashMap<>(); + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, Integer k, Double v) { + LOG.info("udaf nest k v " + k + " " + v); + state.counter.put(k, v); + } + + public void serialize(State state, DataOutputStream out) throws IOException { + int size = state.counter.size(); + out.writeInt(size); + for(Map.Entry it : state.counter.entrySet()){ + out.writeInt(it.getKey()); + out.writeDouble(it.getValue()); + } + } + + public void deserialize(State state, DataInputStream in) throws IOException { + int size = in.readInt(); + for (int i = 0; i < size; ++i) { + Integer key = in.readInt(); + Double value = in.readDouble(); + state.counter.put(key, value); + } + } + + public void merge(State state, State rhs) { + for(Map.Entry it : rhs.counter.entrySet()){ + state.counter.put(it.getKey(), it.getValue()); + } + } + + public HashMap getValue(State state) { + //sort for regression test + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/suites/javaudf_p0/test_javaudaf_return_map.groovy b/regression-test/suites/javaudf_p0/test_javaudaf_return_map.groovy new file mode 100644 index 00000000000000..85b6d042a030ea --- /dev/null +++ b/regression-test/suites/javaudf_p0/test_javaudaf_return_map.groovy @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import org.codehaus.groovy.runtime.IOGroovyMethods + +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.Paths + +suite("test_javaudaf_return_map") { + def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar""" + log.info("Jar path: ${jarPath}".toString()) + try { + try_sql("DROP FUNCTION IF EXISTS aggmap(int);") + try_sql("DROP FUNCTION IF EXISTS aggmap2(int,double);") + try_sql("DROP FUNCTION IF EXISTS aggmap3(int,double);") + try_sql("DROP TABLE IF EXISTS aggdb") + sql """ + CREATE TABLE IF NOT EXISTS aggdb( + `id` INT NULL COMMENT "" , + `d` Double NULL COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2" + ); + """ + + + + sql """ INSERT INTO aggdb VALUES(1,0.01); """ + sql """ INSERT INTO aggdb VALUES(2,0.02); """ + sql """ INSERT INTO aggdb VALUES(3,0.03); """ + sql """ INSERT INTO aggdb VALUES(4,0.04); """ + sql """ INSERT INTO aggdb VALUES(5,0.05); """ + + + sql """ + + CREATE AGGREGATE FUNCTION aggmap(int) RETURNS Map PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumReturnMapInt", + "type"="JAVA_UDF" + ); + + """ + + sql """ + + CREATE AGGREGATE FUNCTION aggmap2(int,double) RETURNS Map PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumReturnMapIntDou", + "type"="JAVA_UDF" + ); + + + """ + + + sql """ + + CREATE AGGREGATE FUNCTION aggmap3(int,double) RETURNS Map PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MyReturnMapString", + "type"="JAVA_UDF" + ); + + + """ + + qt_select_1 """ select aggmap(id) from aggdb; """ + + qt_select_2 """ select aggmap2(id,d) from aggdb; """ + + qt_select_3 """ select aggmap(id) from aggdb group by id; """ + + qt_select_4 """ select aggmap2(id,d) from aggdb group by id; """ + + qt_select_5 """ select aggmap3(id,d) from aggdb; """ + + qt_select_6 """ select aggmap3(id,d) from aggdb group by id; """ + } finally { + try_sql("DROP FUNCTION IF EXISTS aggmap(int);") + try_sql("DROP FUNCTION IF EXISTS aggmap2(int,double);") + try_sql("DROP FUNCTION IF EXISTS aggmap3(int,double);") + try_sql("DROP TABLE IF EXISTS aggdb") + } +}