From 46154f0dc7e990b327937e27295c536a07b87a0b Mon Sep 17 00:00:00 2001 From: zhangstar333 <2561612514@qq.com> Date: Sat, 7 May 2022 16:35:49 +0800 Subject: [PATCH 1/4] [Vectorized][java-udf] add datetime&&largeint&&decimal type to java-udf --- be/src/util/jni-util.cpp | 8 +- be/src/vec/functions/function_java_udf.cpp | 25 +- .../udf/java-user-defined-function.md | 19 +- .../udf/java-user-defined-function.md | 20 +- .../doris/analysis/CreateFunctionStmt.java | 8 + .../org/apache/doris/udf/UdfExecutor.java | 242 +++++++++++++++--- 6 files changed, 274 insertions(+), 48 deletions(-) diff --git a/be/src/util/jni-util.cpp b/be/src/util/jni-util.cpp index 54ca27a64d7795..94355ec9513d7c 100644 --- a/be/src/util/jni-util.cpp +++ b/be/src/util/jni-util.cpp @@ -16,8 +16,10 @@ // under the License. #include "util/jni-util.h" + #ifdef LIBJVM #include +#include "jni_md.h" #include #include "gutil/once.h" @@ -45,8 +47,10 @@ void FindOrCreateJavaVM() { vm_args.nOptions = 1; vm_args.ignoreUnrecognized = JNI_TRUE; - int res = JNI_CreateJavaVM(&g_vm, (void**)&env, &vm_args); - DCHECK_LT(res, 0) << "Failed tp create JVM, code= " << res; + jint res = JNI_CreateJavaVM(&g_vm, (void**)&env, &vm_args); + if (JNI_OK != res) { + DCHECK(false) << "Failed to create JVM, code= " << res; + } } else { CHECK_EQ(rv, 0) << "Could not find any created Java VM"; CHECK_EQ(num_vms, 1) << "No VMs returned"; diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index ba7f58259aeb4b..00be6c5c424309 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -49,10 +49,11 @@ JavaFunctionCall::JavaFunctionCall(const TFunction& fn, const DataTypes& argumen Status JavaFunctionCall::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) { - DCHECK(executor_cl_ == NULL) << "Init() already called!"; - JNIEnv* env; + JNIEnv* env = nullptr; RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env)); - if (env == 
NULL) return Status::InternalError("Failed to get/create JVM"); + if (env == nullptr) { + return Status::InternalError("Failed to get/create JVM"); + } RETURN_IF_ERROR(JniUtil::GetGlobalClassRef(env, EXECUTOR_CLASS, &executor_cl_)); executor_ctor_id_ = env->GetMethodID(executor_cl_, "", EXECUTOR_CTOR_SIGNATURE); RETURN_ERROR_IF_EXC(env); @@ -101,7 +102,7 @@ Status JavaFunctionCall::prepare(FunctionContext* context, Status JavaFunctionCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t num_rows, bool dry_run) { - JNIEnv* env; + JNIEnv* env = nullptr; RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env)); JniContext* jni_ctx = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); @@ -109,6 +110,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, for (size_t col_idx : arguments) { ColumnWithTypeAndName& column = block.get_by_position(col_idx); auto col = column.column->convert_to_full_column_if_const(); + auto& col_type = column.type; if (!_argument_types[arg_idx]->equals(*column.type)) { return Status::InvalidArgument(strings::Substitute( "$0-th input column's type $1 does not equal to required type $2", arg_idx, @@ -117,6 +119,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, auto data_col = col; if (auto* nullable = check_and_get_column(*col)) { data_col = nullable->get_nested_column_ptr(); + col_type = remove_nullable(col_type); auto null_col = check_and_get_column>(nullable->get_null_map_column_ptr()); jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = @@ -124,12 +127,15 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, } else { jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1; } - if (const ColumnString* str_col = check_and_get_column(data_col.get())) { + WhichDataType type(col_type); + if (type.is_string_or_fixed_string()) { + const ColumnString* str_col = assert_cast(data_col.get()); 
jni_ctx->input_values_buffer_ptr.get()[arg_idx] = reinterpret_cast(str_col->get_chars().data()); jni_ctx->input_offsets_ptrs.get()[arg_idx] = reinterpret_cast(str_col->get_offsets().data()); - } else if (data_col->is_numeric()) { + } else if (type.is_int() || type.is_uint() || type.is_float() || + type.is_date_or_datetime() || type.is_decimal()) { jni_ctx->input_values_buffer_ptr.get()[arg_idx] = reinterpret_cast(data_col->get_raw_data().data); } else { @@ -151,7 +157,8 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, *(jni_ctx->output_null_value) = reinterpret_cast(null_col->get_data().data()); #ifndef EVALUATE_JAVA_UDF #define EVALUATE_JAVA_UDF \ - if (const ColumnString* str_col = check_and_get_column(data_col.get())) { \ + if (data_col->is_column_string()) { \ + const ColumnString* str_col = assert_cast(data_col.get()); \ ColumnString::Chars& chars = const_cast(str_col->get_chars()); \ ColumnString::Offsets& offsets = \ const_cast(str_col->get_offsets()); \ @@ -177,7 +184,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, executor_evaluate_id_, \ nullptr); \ } \ - } else if (data_col->is_numeric()) { \ + } else if (data_col->is_numeric() || data_col->is_column_decimal()) { \ data_col->reserve(num_rows); \ data_col->resize(num_rows); \ *(jni_ctx->output_value_buffer) = \ @@ -205,7 +212,7 @@ Status JavaFunctionCall::close(FunctionContext* context, FunctionContext::FunctionStateScope scope) { JniContext* jni_ctx = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); - if (jni_ctx != NULL) { + if (jni_ctx != nullptr) { delete jni_ctx; context->set_function_state(FunctionContext::THREAD_LOCAL, nullptr); } diff --git a/docs/en/ecosystem/udf/java-user-defined-function.md b/docs/en/ecosystem/udf/java-user-defined-function.md index 05d3be39987c2a..9ccc87aa73f2f8 100644 --- a/docs/en/ecosystem/udf/java-user-defined-function.md 
+++ b/docs/en/ecosystem/udf/java-user-defined-function.md @@ -44,6 +44,23 @@ To use Java UDF, the main entry of UDF must be the `evaluate` function. This is It is worth mentioning that this example is not only the Java UDF supported by Doris, but also the UDF supported by Hive, that's to say, for users, Hive UDF can be directly migrated to Doris. +#### Type correspondence + +|UDF Type|Argument Type| +|----|---------| +|TinyInt|Byte| +|SmallInt|Short| +|Int|Integer| +|BigInt|Long| +|LargeInt|BigInteger| +|Float|Float| +|Double|Double| +|Date|LocalDate| +|Datetime|LocalDateTime| +|Char|String| +|Varchar|String| +|Decimal|BigDecimal| + ## Create UDF Currently, UDAF and UDTF are not supported. @@ -85,6 +102,6 @@ Examples of Java UDF are provided in the `samples/doris-demo/java-udf-demo/` dir ## Unsupported Use Case At present, Java UDF is still in the process of continuous development, so some features are **not completed**. -1. Complex data types (date, HLL, bitmap) are not supported. +1. Complex data types (HLL, bitmap) are not supported. 2. Memory management and statistics of JVM and Doris have not been unified. 
diff --git a/docs/zh-CN/ecosystem/udf/java-user-defined-function.md b/docs/zh-CN/ecosystem/udf/java-user-defined-function.md index ea810835201b6c..8306e842176ce7 100644 --- a/docs/zh-CN/ecosystem/udf/java-user-defined-function.md +++ b/docs/zh-CN/ecosystem/udf/java-user-defined-function.md @@ -43,6 +43,24 @@ Java UDF 为用户提供UDF编写的Java接口,以方便用户使用Java语言 使用Java代码编写UDF,UDF的主入口必须为 `evaluate` 函数。这一点与Hive等其他引擎保持一致。在本示例中,我们编写了 `AddOne` UDF来完成对整型输入进行加一的操作。 值得一提的是,本例不只是Doris支持的Java UDF,同时还是Hive支持的UDF,也就是说,对于用户来讲,Hive UDF是可以直接迁移至Doris的。 +#### 类型对应关系 + +|UDF Type|Argument Type| +|----|---------| +|TinyInt|Byte| +|SmallInt|Short| +|Int|Integer| +|BigInt|Long| +|LargeInt|BigInteger| +|Float|Float| +|Double|Double| +|Date|LocalDate| +|Datetime|LocalDateTime| +|Char|String| +|Varchar|String| +|Decimal|BigDecimal| + + ## 创建 UDF 目前暂不支持 UDAF 和 UDTF @@ -84,6 +102,6 @@ UDF 的使用与普通的函数方式一致,唯一的区别在于,内置函 ## 暂不支持的场景 当前Java UDF仍然处在持续的开发过程中,所以部分功能**尚不完善**。包括: -1. 不支持复杂数据类型(Date,HLL,Bitmap) +1. 不支持复杂数据类型(HLL,Bitmap) 2. 
尚未统一JVM和Doris的内存管理以及统计信息 diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index 3036f1bde32061..89f5603fd78831 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -57,11 +57,15 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; +import java.math.BigDecimal; +import java.math.BigInteger; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.List; import java.util.Map; import java.util.Set; @@ -368,6 +372,10 @@ private void analyzeJavaUdf(String clazz) throws AnalysisException { .put(PrimitiveType.CHAR, Sets.newHashSet(String.class)) .put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class)) .put(PrimitiveType.STRING, Sets.newHashSet(String.class)) + .put(PrimitiveType.DATE, Sets.newHashSet(LocalDate.class)) + .put(PrimitiveType.DATETIME, Sets.newHashSet(LocalDateTime.class)) + .put(PrimitiveType.LARGEINT, Sets.newHashSet(BigInteger.class)) + .put(PrimitiveType.DECIMALV2, Sets.newHashSet(BigDecimal.class)) .build(); private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname) diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 932695057cd789..151d58b40ecb36 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -33,11 +33,16 @@ import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.math.BigDecimal; +import 
java.math.BigInteger; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.ArrayList; +import java.util.Arrays; public class UdfExecutor { private static final Logger LOG = Logger.getLogger(UdfExecutor.class); @@ -95,7 +100,11 @@ public enum JavaUdfDataType { DOUBLE("DOUBLE", TPrimitiveType.DOUBLE, 8), CHAR("CHAR", TPrimitiveType.CHAR, 0), VARCHAR("VARCHAR", TPrimitiveType.VARCHAR, 0), - STRING("STRING", TPrimitiveType.STRING, 0); + STRING("STRING", TPrimitiveType.STRING, 0), + DATE("DATE", TPrimitiveType.DATE, 8), + DATETIME("DATETIME", TPrimitiveType.DATETIME, 8), + LARGEINT("LARGEINT", TPrimitiveType.LARGEINT, 16), + DECIMALV2("DECIMALV2", TPrimitiveType.DECIMALV2, 16); private final String description; private final TPrimitiveType thriftType; @@ -139,6 +148,14 @@ public static JavaUdfDataType getType(Class c) { return JavaUdfDataType.CHAR; } else if (c == String.class) { return JavaUdfDataType.STRING; + } else if (c == LocalDate.class) { + return JavaUdfDataType.DATE; + } else if (c == LocalDateTime.class) { + return JavaUdfDataType.DATETIME; + } else if (c == BigInteger.class) { + return JavaUdfDataType.LARGEINT; + } else if (c == BigDecimal.class) { + return JavaUdfDataType.DECIMALV2; } return JavaUdfDataType.INVALID_TYPE; } @@ -162,14 +179,12 @@ public static boolean isSupported(Type t) { */ public UdfExecutor(byte[] thriftParams) throws Exception { TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); - TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); try { deserializer.deserialize(request, thriftParams); } catch (TException e) { throw new InternalException(e.getMessage()); } - String className = request.fn.scalar_fn.symbol; String jarFile = request.location; Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first; @@ -310,37 +325,99 @@ private boolean 
storeUdfResult(Object obj, long row) throws UdfRuntimeException switch (retType) { case BOOLEAN: { boolean val = (boolean) obj; - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), val ? (byte) 1 : 0); + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + val ? (byte) 1 : 0); return true; } case TINYINT: { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (byte) obj); + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (byte) obj); return true; } case SMALLINT: { - UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (short) obj); + UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (short) obj); return true; } case INT: { - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (int) obj); + UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (int) obj); return true; } case BIGINT: { - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (long) obj); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (long) obj); return true; } case FLOAT: { - UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (float) obj); + UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (float) obj); return true; } case DOUBLE: { - UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (double) obj); + UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (double) obj); + return true; + } + case DATE: { + LocalDate date = (LocalDate) 
obj; + long time = + convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), 0, 0, 0, + true); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + time); + return true; + } + case DATETIME: { + LocalDateTime date = (LocalDateTime) obj; + long time = + convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), + date.getHour(), + date.getMinute(), date.getSecond(), false); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + time); + return true; + } + case LARGEINT: { + BigInteger data = (BigInteger) obj; + byte[] bytes = convertByteOrder(data.toByteArray()); + + //here value is 16 bytes, so if result data greater than the maximum of 16 bytes + //it will return a wrong num to backend; + byte[] value = new byte[16]; + //check data is negative + if (data.signum() == -1) { + Arrays.fill(value, (byte) -1); + } + for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { + value[index] = bytes[index]; + } + + UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), value.length); + return true; + } + case DECIMALV2: { + BigInteger data = ((BigDecimal) obj).unscaledValue(); + byte[] bytes = convertByteOrder(data.toByteArray()); + + byte[] value = new byte[16]; + if (data.signum() == -1) { + Arrays.fill(value, (byte) -1); + } + + for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { + value[index] = bytes[index]; + } + + UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), value.length); return true; } case CHAR: case VARCHAR: - case STRING: - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); + case STRING: { + long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr_); byte[] 
bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); if (outputOffset + bytes.length + 1 > bufferSize) { return false; @@ -352,6 +429,7 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + outputOffset - bytes.length - 1, bytes.length); return true; + } default: throw new UdfRuntimeException("Unsupported return type: " + retType); } @@ -366,48 +444,81 @@ private void allocateInputObjects(long row) throws UdfRuntimeException { for (int i = 0; i < argTypes.length; ++i) { switch (argTypes[i]) { case BOOLEAN: - inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset( - inputBufferPtrs, i)) + row); + inputObjects_[i] = UdfUtils.UNSAFE.getBoolean(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); break; case TINYINT: - inputObjects[i] = UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset( - inputBufferPtrs, i)) + row); + inputObjects_[i] = UdfUtils.UNSAFE.getByte(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); break; case SMALLINT: - inputObjects[i] = UdfUtils.UNSAFE.getShort(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset( - inputBufferPtrs, i)) + 2L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getShort(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 2L * row); break; case INT: - inputObjects[i] = UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset( - inputBufferPtrs, i)) + 4L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); break; case BIGINT: - inputObjects[i] = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset( - 
inputBufferPtrs, i)) + 8L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); break; case FLOAT: - inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset( - inputBufferPtrs, i)) + 4L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getFloat(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); break; case DOUBLE: - inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset( - inputBufferPtrs, i)) + 8L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getDouble(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + break; + case DATE: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = convertToDate(data); + break; + } + case DATETIME: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = convertToDateTime(data); break; + } + case LARGEINT: { + long base = + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 16L * row; + byte[] bytes = new byte[16]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); + + inputObjects_[i] = new BigInteger(convertByteOrder(bytes)); + break; + } + case DECIMALV2: { + long base = + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 16L * row; + byte[] bytes = new byte[16]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); + + BigInteger value = new BigInteger(convertByteOrder(bytes)); + inputObjects_[i] = new BigDecimal(value, 9); + break; + } case CHAR: case VARCHAR: - case STRING: - long offset = 
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); + case STRING: { + long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * row)); long numBytes = row == 0 ? offset - 1 : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))) - 1; - long base = row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) : - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + offset - numBytes - 1; + UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * (row - 1))) - 1; + long base = + row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) : + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + + offset - numBytes - 1; byte[] bytes = new byte[(int) numBytes]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); break; + } default: throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); } @@ -416,7 +527,7 @@ private void allocateInputObjects(long row) throws UdfRuntimeException { private URLClassLoader getClassLoader(String jarPath) throws MalformedURLException { URL url = new File(jarPath).toURI().toURL(); - return URLClassLoader.newInstance(new URL[]{url}, getClass().getClassLoader()); + return URLClassLoader.newInstance(new URL[] {url}, getClass().getClassLoader()); } /** @@ -530,4 +641,65 @@ private void init(String jarPath, String udfPath, throw new UdfRuntimeException("Unable to call create UDF instance.", e); } } -} + + // input is a 64bit num from backend, and then get year, month, day, hour, minus, second by the order of bits + // return a 
new LocalDateTime data to evaluate method; + private LocalDateTime convertToDateTime(long date) { + int year = (int) (date >> 48); + int yearMonth = (int) (date >> 40); + int yearMonthDay = (int) (date >> 32); + + int month = (yearMonth & 0XFF); + int day = (yearMonthDay & 0XFF); + + int hourMinuteSecond = (int) (date % (1 << 31)); + int minuteTypeNeg = (hourMinuteSecond % (1 << 16)); + + int hour = (hourMinuteSecond >> 24); + int minute = ((hourMinuteSecond >> 16) & 0XFF); + int second = (minuteTypeNeg >> 4); + //here don't need those bits are type = ((minus_type_neg >> 1) & 0x7); + + LocalDateTime value = LocalDateTime.of(year, month, day, hour, minute, second); + return value; + } + + private LocalDate convertToDate(long date) { + int year = (int) (date >> 48); + int yearMonth = (int) (date >> 40); + int yearMonthDay = (int) (date >> 32); + + int month = (yearMonth & 0XFF); + int day = (yearMonthDay & 0XFF); + LocalDate value = LocalDate.of(year, month, day); + return value; + } + + //input is the second, minute, hours, day , month and year respectively + //and then combining all num to a 64bit value return to backend; + long convertDateTimeToLong(int year, int month, int day, int hour, int minute, int second, Boolean isDate) { + long time = 0; + time = time + year; + time = (time << 8) + month; + time = (time << 8) + day; + time = (time << 8) + hour; + time = (time << 8) + minute; + time = (time << 12) + second; + int type = isDate ? 
2 : 3; + time = (time << 3) + type; + //this bit is int neg = 0; + time = (time << 1); + return time; + } + + // Change the order of the bytes, Because JVM is Big-Endian , x86 is Little-Endian + private byte[] convertByteOrder(byte[] bytes) { + int length = bytes.length; + for (int i = 0; i < length / 2; ++i) { + byte temp = bytes[i]; + bytes[i] = bytes[length - 1 - i]; + bytes[length - 1 - i] = temp; + } + return bytes; + } +} \ No newline at end of file From 59214d40a54fd27187a1e0db0679bf6b11de323c Mon Sep 17 00:00:00 2001 From: zhangstar333 <2561612514@qq.com> Date: Thu, 12 May 2022 11:25:18 +0800 Subject: [PATCH 2/4] add java-udf fe unit test --- be/src/util/jni-util.cpp | 2 +- be/src/vec/functions/function_java_udf.cpp | 9 +- .../org/apache/doris/udf/UdfExecutor.java | 2 +- .../org/apache/doris/udf/DateTimeUdf.java | 30 ++ .../java/org/apache/doris/udf/DecimalUdf.java | 31 ++ .../org/apache/doris/udf/LargeIntUdf.java | 31 ++ .../org/apache/doris/udf/UdfExecutorTest.java | 300 ++++++++++++++++++ 7 files changed, 397 insertions(+), 8 deletions(-) create mode 100644 fe/java-udf/src/test/java/org/apache/doris/udf/DateTimeUdf.java create mode 100644 fe/java-udf/src/test/java/org/apache/doris/udf/DecimalUdf.java create mode 100644 fe/java-udf/src/test/java/org/apache/doris/udf/LargeIntUdf.java diff --git a/be/src/util/jni-util.cpp b/be/src/util/jni-util.cpp index 94355ec9513d7c..237139ce80f693 100644 --- a/be/src/util/jni-util.cpp +++ b/be/src/util/jni-util.cpp @@ -19,7 +19,7 @@ #ifdef LIBJVM #include -#include "jni_md.h" +#include #include #include "gutil/once.h" diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index 00be6c5c424309..f22321e69de019 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -110,7 +110,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, for (size_t col_idx : arguments) { ColumnWithTypeAndName& column = 
block.get_by_position(col_idx); auto col = column.column->convert_to_full_column_if_const(); - auto& col_type = column.type; if (!_argument_types[arg_idx]->equals(*column.type)) { return Status::InvalidArgument(strings::Substitute( "$0-th input column's type $1 does not equal to required type $2", arg_idx, @@ -119,7 +118,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, auto data_col = col; if (auto* nullable = check_and_get_column(*col)) { data_col = nullable->get_nested_column_ptr(); - col_type = remove_nullable(col_type); auto null_col = check_and_get_column>(nullable->get_null_map_column_ptr()); jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = @@ -127,15 +125,14 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, } else { jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1; } - WhichDataType type(col_type); - if (type.is_string_or_fixed_string()) { + + if (data_col->is_column_string()) { const ColumnString* str_col = assert_cast(data_col.get()); jni_ctx->input_values_buffer_ptr.get()[arg_idx] = reinterpret_cast(str_col->get_chars().data()); jni_ctx->input_offsets_ptrs.get()[arg_idx] = reinterpret_cast(str_col->get_offsets().data()); - } else if (type.is_int() || type.is_uint() || type.is_float() || - type.is_date_or_datetime() || type.is_decimal()) { + } else if (data_col->is_numeric() || data_col->is_column_decimal()) { jni_ctx->input_values_buffer_ptr.get()[arg_idx] = reinterpret_cast(data_col->get_raw_data().data); } else { diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 151d58b40ecb36..d1d5a91edeadf5 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -677,7 +677,7 @@ private LocalDate convertToDate(long date) { //input is the second, minute, hours, day , month and year respectively //and then combining all num 
to a 64bit value return to backend; - long convertDateTimeToLong(int year, int month, int day, int hour, int minute, int second, Boolean isDate) { + private long convertDateTimeToLong(int year, int month, int day, int hour, int minute, int second, boolean isDate) { long time = 0; time = time + year; time = (time << 8) + month; diff --git a/fe/java-udf/src/test/java/org/apache/doris/udf/DateTimeUdf.java b/fe/java-udf/src/test/java/org/apache/doris/udf/DateTimeUdf.java new file mode 100644 index 00000000000000..98eaa35fbe873f --- /dev/null +++ b/fe/java-udf/src/test/java/org/apache/doris/udf/DateTimeUdf.java @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.udf; + +import java.time.LocalDateTime; + +public class DateTimeUdf { + /** + * input argument of datetime. 
+ * return year + */ + public int evaluate(LocalDateTime a) { + return a.getYear(); + } +} diff --git a/fe/java-udf/src/test/java/org/apache/doris/udf/DecimalUdf.java b/fe/java-udf/src/test/java/org/apache/doris/udf/DecimalUdf.java new file mode 100644 index 00000000000000..8ec393ef929bc1 --- /dev/null +++ b/fe/java-udf/src/test/java/org/apache/doris/udf/DecimalUdf.java @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.udf; + +import java.math.BigDecimal; + +public class DecimalUdf { + /** + * a input argument of decimal. + * b input argument of decimal + * sum of a and b + */ + public BigDecimal evaluate(BigDecimal a, BigDecimal b) { + return a.add(b); + } +} diff --git a/fe/java-udf/src/test/java/org/apache/doris/udf/LargeIntUdf.java b/fe/java-udf/src/test/java/org/apache/doris/udf/LargeIntUdf.java new file mode 100644 index 00000000000000..2a12ee043d45bd --- /dev/null +++ b/fe/java-udf/src/test/java/org/apache/doris/udf/LargeIntUdf.java @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.udf; + +import java.math.BigInteger; + +public class LargeIntUdf { + /** + * input argument of largeint. + * input argument of largeint + * sum of a and b + */ + public BigInteger evaluate(BigInteger a, BigInteger b) { + return a.add(b); + } +} diff --git a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java index d72c62149124b4..ab2a5af1b3f598 100644 --- a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java +++ b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java @@ -32,12 +32,198 @@ import org.apache.thrift.protocol.TBinaryProtocol; import org.junit.Test; +import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; public class UdfExecutorTest { + + @Test + public void testDateTimeUdf() throws Exception { + TScalarFunction scalarFunction = new TScalarFunction(); + scalarFunction.symbol = "org.apache.doris.udf.DateTimeUdf"; + + TFunction fn = new TFunction(); + fn.setBinaryType(TFunctionBinaryType.JAVA_UDF); + TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); + typeNode.setScalarType(new TScalarType(TPrimitiveType.INT)); + fn.setRetType(new TTypeDesc(Collections.singletonList(typeNode))); + + TTypeNode 
typeNodeArg = new TTypeNode(TTypeNodeType.SCALAR); + typeNodeArg.setScalarType(new TScalarType(TPrimitiveType.DATETIME)); + TTypeDesc typeDescArg = new TTypeDesc(Collections.singletonList(typeNodeArg)); + fn.arg_types = Arrays.asList(typeDescArg); + + fn.scalar_fn = scalarFunction; + fn.name = new TFunctionName("DateTimeUdf"); + + long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); + int batchSize = 10; + UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); + + TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); + params.setBatchSizePtr(batchSizePtr); + params.setFn(fn); + + long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); + + UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); + UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); + + params.setOutputBufferPtr(outputBufferPtr); + params.setOutputNullPtr(outputNullPtr); + + int numCols = 1; + long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + + long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(8 * batchSize); + long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); + + UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); + UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); + + long[] inputLongDateTime = + new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, + 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, + 565212791469375654L, 565494266446086310L}; + + for (int i = 0; i < batchSize; ++i) { + UdfUtils.UNSAFE.putLong(null, inputBuffer1 + i * 8, inputLongDateTime[i]); + UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); + } + + params.setInputBufferPtrs(inputBufferPtr); + params.setInputNullsPtrs(inputNullPtr); + 
params.setInputOffsetsPtrs(0); + + TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); + TSerializer serializer = new TSerializer(factory); + + UdfExecutor executor = new UdfExecutor(serializer.serialize(params)); + executor.evaluate(); + + for (int i = 0; i < batchSize; ++i) { + assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); + assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == (2000 + i)); + } + } + + @Test + public void testDecimalUdf() throws Exception { + TScalarFunction scalarFunction = new TScalarFunction(); + scalarFunction.symbol = "org.apache.doris.udf.DecimalUdf"; + TFunction fn = new TFunction(); + fn.binary_type = TFunctionBinaryType.JAVA_UDF; + TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); + TScalarType scalarType = new TScalarType(TPrimitiveType.DECIMALV2); + scalarType.setScale(9); + scalarType.setPrecision(27); + typeNode.scalar_type = scalarType; + TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); + fn.ret_type = typeDesc; + fn.arg_types = Arrays.asList(typeDesc, typeDesc); + fn.scalar_fn = scalarFunction; + fn.name = new TFunctionName("DecimalUdf"); + + long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8); + int batchSize = 10; + UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); + + TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); + params.setBatchSizePtr(batchSizePtr); + params.setFn(fn); + + long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); + + long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); + long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); + + UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); + UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); + + params.setOutputBufferPtr(outputBufferPtr); + params.setOutputNullPtr(outputNullPtr); + + int numCols = 2; + long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + long inputNullPtr = 
UdfUtils.UNSAFE.allocateMemory(8 * numCols); + + long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); + long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); + + long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); + long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); + + UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); + UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); + UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); + UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); + + long[] inputLong = + new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, + 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, + 565212791469375654L, 565494266446086310L}; + + BigDecimal[] decimalArray = new BigDecimal[10]; + for (int i = 0; i < batchSize; ++i) { + BigInteger temp = BigInteger.valueOf(inputLong[i]); + decimalArray[i] = new BigDecimal(temp, 9); + } + + BigDecimal decimal2 = new BigDecimal(BigInteger.valueOf(0L), 9); + byte[] intput2 = convertByteOrder(decimal2.unscaledValue().toByteArray()); + byte[] value2 = new byte[16]; + if (decimal2.signum() == -1) { + Arrays.fill(value2, (byte) -1); + } + for (int index = 0; index < Math.min(intput2.length, value2.length); ++index) { + value2[index] = intput2[index]; + } + + for (int i = 0; i < batchSize; ++i) { + byte[] intput1 = convertByteOrder(decimalArray[i].unscaledValue().toByteArray()); + byte[] value1 = new byte[16]; + if (decimalArray[i].signum() == -1) { + Arrays.fill(value1, (byte) -1); + } + for (int index = 0; index < Math.min(intput1.length, value1.length); ++index) { + value1[index] = intput1[index]; + } + UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length); + UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length); + UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); + 
UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); + } + + params.setInputBufferPtrs(inputBufferPtr); + params.setInputNullsPtrs(inputNullPtr); + params.setInputOffsetsPtrs(0); + + TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); + TSerializer serializer = new TSerializer(factory); + + UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params)); + udfExecutor.evaluate(); + + for (int i = 0; i < batchSize; ++i) { + byte[] bytes = new byte[16]; + assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); + UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length); + + BigInteger interger = new BigInteger(convertByteOrder(bytes)); + BigDecimal result = new BigDecimal(interger, 9); + assert (result.equals(decimalArray[i])); + } + } + @Test public void testConstantOneUdf() throws Exception { TScalarFunction scalarFunction = new TScalarFunction(); @@ -303,4 +489,118 @@ public void testStringConcatUdf() throws Exception { assert (UdfUtils.UNSAFE.getByte(null, outputNull + i) == 0); } } + + @Test + public void testLargeIntUdf() throws Exception { + TScalarFunction scalarFunction = new TScalarFunction(); + scalarFunction.symbol = "org.apache.doris.udf.LargeIntUdf"; + TFunction fn = new TFunction(); + fn.binary_type = TFunctionBinaryType.JAVA_UDF; + TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); + typeNode.scalar_type = new TScalarType(TPrimitiveType.LARGEINT); + + TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); + + fn.ret_type = typeDesc; + fn.arg_types = Arrays.asList(typeDesc, typeDesc); + fn.scalar_fn = scalarFunction; + fn.name = new TFunctionName("LargeIntUdf"); + + long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8); + int batchSize = 10; + UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); + + TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); + params.setBatchSizePtr(batchSizePtr); + params.setFn(fn); + + long outputBufferPtr = 
UdfUtils.UNSAFE.allocateMemory(8); + long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); + + long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); + long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); + + UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); + UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); + + params.setOutputBufferPtr(outputBufferPtr); + params.setOutputNullPtr(outputNullPtr); + + int numCols = 2; + long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + + long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); + long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); + + long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); + long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); + + UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); + UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); + UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); + UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); + + long[] inputLong = + new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, + 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, + 565212791469375654L, 565494266446086310L}; + + BigInteger[] integerArray = new BigInteger[10]; + for (int i = 0; i < batchSize; ++i) { + integerArray[i] = BigInteger.valueOf(inputLong[i]); + } + BigInteger integer2 = BigInteger.valueOf(1L); + byte[] intput2 = convertByteOrder(integer2.toByteArray()); + byte[] value2 = new byte[16]; + if (integer2.signum() == -1) { + Arrays.fill(value2, (byte) -1); + } + for (int index = 0; index < Math.min(intput2.length, value2.length); ++index) { + value2[index] = intput2[index]; + } + + for (int i = 0; i < batchSize; ++i) { + byte[] intput1 = convertByteOrder(integerArray[i].toByteArray()); + byte[] value1 = new byte[16]; + if (integerArray[i].signum() == 
-1) { + Arrays.fill(value1, (byte) -1); + } + for (int index = 0; index < Math.min(intput1.length, value1.length); ++index) { + value1[index] = intput1[index]; + } + UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length); + UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length); + UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); + UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); + } + + params.setInputBufferPtrs(inputBufferPtr); + params.setInputNullsPtrs(inputNullPtr); + params.setInputOffsetsPtrs(0); + + TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); + TSerializer serializer = new TSerializer(factory); + + UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params)); + udfExecutor.evaluate(); + + for (int i = 0; i < batchSize; ++i) { + byte[] bytes = new byte[16]; + assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); + UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length); + BigInteger result = new BigInteger(convertByteOrder(bytes)); + assert (result.equals(integerArray[i].add(BigInteger.valueOf(1)))); + } + } + + public byte[] convertByteOrder(byte[] bytes) { + int length = bytes.length; + for (int i = 0; i < length / 2; ++i) { + byte temp = bytes[i]; + bytes[i] = bytes[length - 1 - i]; + bytes[length - 1 - i] = temp; + } + return bytes; + } } From ddba686d7ec22ff8bf90531c41c68725c0289232 Mon Sep 17 00:00:00 2001 From: zhangstar333 <2561612514@qq.com> Date: Tue, 17 May 2022 11:02:41 +0800 Subject: [PATCH 3/4] rebase code to master --- .../src/main/java/org/apache/doris/udf/UdfExecutor.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index d1d5a91edeadf5..2340d71e9ff826 100644 --- 
a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -382,7 +382,7 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException BigInteger data = (BigInteger) obj; byte[] bytes = convertByteOrder(data.toByteArray()); - //here value is 16 bytes, so if result data greater than the maximum of 16 bytes + //here value is 16 bytes, so if result data greater than the maximum of 16 bytes //it will return a wrong num to backend; byte[] value = new byte[16]; //check data is negative @@ -400,7 +400,7 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException case DECIMALV2: { BigInteger data = ((BigDecimal) obj).unscaledValue(); byte[] bytes = convertByteOrder(data.toByteArray()); - + //TODO: this may overflow as well (result wider than 16 bytes); find a better way to handle it byte[] value = new byte[16]; if (data.signum() == -1) { Arrays.fill(value, (byte) -1); From e6eb4322b7728e1d63c41f7dd12bd8a9e6cf877f Mon Sep 17 00:00:00 2001 From: zhangstar333 <2561612514@qq.com> Date: Thu, 19 May 2022 19:00:48 +0800 Subject: [PATCH 4/4] rebase to solve conflict --- .../org/apache/doris/udf/UdfExecutor.java | 80 +++++++++---------- .../org/apache/doris/udf/UdfExecutorTest.java | 2 +- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 2340d71e9ff826..3a702e82d4b446 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -80,7 +80,7 @@ public class UdfExecutor { // Pre-constructed input objects for the UDF. This minimizes object creation overhead // as these objects are reused across calls to evaluate().
private Object[] inputObjects; - // inputArgs_[i] is either inputObjects_[i] or null + // inputArgs[i] is either inputObjects[i] or null private Object[] inputArgs; private long outputOffset; @@ -302,7 +302,7 @@ public Method getMethod() { return method; } - // Sets the result object 'obj' into the outputBufferPtr_ and outputNullPtr_ + // Sets the result object 'obj' into the outputBufferPtr and outputNullPtr private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException { if (obj == null) { assert (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1); @@ -325,37 +325,37 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException switch (retType) { case BOOLEAN: { boolean val = (boolean) obj; - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), val ? (byte) 1 : 0); return true; } case TINYINT: { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (byte) obj); return true; } case SMALLINT: { - UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (short) obj); return true; } case INT: { - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (int) obj); return true; } case BIGINT: { - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (long) obj); return true; } case FLOAT: { -
UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (float) obj); return true; } case DOUBLE: { - UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), (double) obj); return true; } @@ -364,7 +364,7 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException long time = convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), 0, 0, 0, true); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } @@ -374,7 +374,7 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), date.getHour(), date.getMinute(), date.getSecond(), false); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } @@ -394,7 +394,7 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), value.length); + UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); return true; } case DECIMALV2: { @@ -411,13 +411,13 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), 
value.length); + UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); return true; } case CHAR: case VARCHAR: case STRING: { - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr_); + long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); if (outputOffset + bytes.length + 1 > bufferSize) { return false; @@ -444,75 +444,75 @@ private void allocateInputObjects(long row) throws UdfRuntimeException { for (int i = 0; i < argTypes.length; ++i) { switch (argTypes[i]) { case BOOLEAN: - inputObjects_[i] = UdfUtils.UNSAFE.getBoolean(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); + inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); break; case TINYINT: - inputObjects_[i] = UdfUtils.UNSAFE.getByte(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); + inputObjects[i] = UdfUtils.UNSAFE.getByte(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); break; case SMALLINT: - inputObjects_[i] = UdfUtils.UNSAFE.getShort(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 2L * row); + inputObjects[i] = UdfUtils.UNSAFE.getShort(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 2L * row); break; case INT: - inputObjects_[i] = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); + inputObjects[i] = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row); break; case BIGINT: - inputObjects_[i] = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + 
inputObjects[i] = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); break; case FLOAT: - inputObjects_[i] = UdfUtils.UNSAFE.getFloat(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); + inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row); break; case DOUBLE: - inputObjects_[i] = UdfUtils.UNSAFE.getDouble(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); break; case DATE: { long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); - inputObjects_[i] = convertToDate(data); + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); + inputObjects[i] = convertToDate(data); break; } case DATETIME: { long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); - inputObjects_[i] = convertToDateTime(data); + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); + inputObjects[i] = convertToDateTime(data); break; } case LARGEINT: { long base = - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 16L * row; + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 16L * row; byte[] bytes = new byte[16]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); - inputObjects_[i] = new BigInteger(convertByteOrder(bytes)); + inputObjects[i] = new BigInteger(convertByteOrder(bytes)); break; } case DECIMALV2: { long base = - UdfUtils.UNSAFE.getLong(null, 
UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 16L * row; + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 16L * row; byte[] bytes = new byte[16]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); BigInteger value = new BigInteger(convertByteOrder(bytes)); - inputObjects_[i] = new BigDecimal(value, 9); + inputObjects[i] = new BigDecimal(value, 9); break; } case CHAR: case VARCHAR: case STRING: { long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * row)); + UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); long numBytes = row == 0 ? offset - 1 : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * (row - 1))) - 1; + UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))) - 1; long base = - row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) : - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row == 0 ? 
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) : + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + offset - numBytes - 1; byte[] bytes = new byte[(int) numBytes]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); diff --git a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java index ab2a5af1b3f598..e999c3d45e9c2c 100644 --- a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java +++ b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java @@ -114,7 +114,7 @@ public void testDateTimeUdf() throws Exception { assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == (2000 + i)); } } - + @Test public void testDecimalUdf() throws Exception { TScalarFunction scalarFunction = new TScalarFunction();