Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.common.classloader;

import org.apache.doris.common.jni.utils.ExpiringMap;
import org.apache.doris.common.jni.utils.UdfClassCache;

import com.google.common.collect.Streams;
import org.apache.log4j.Logger;
Expand Down Expand Up @@ -45,7 +46,7 @@
public class ScannerLoader {
public static final Logger LOG = Logger.getLogger(ScannerLoader.class);
private static final Map<String, Class<?>> loadedClasses = new HashMap<>();
private static final ExpiringMap<String, ClassLoader> udfLoadedClasses = new ExpiringMap<String, ClassLoader>();
private static final ExpiringMap<String, UdfClassCache> udfLoadedClasses = new ExpiringMap<>();
private static final String CLASS_SUFFIX = ".class";
private static final String LOAD_PACKAGE = "org.apache.doris";

Expand All @@ -65,14 +66,14 @@ public void loadAllScannerJars() {
});
}

public static ClassLoader getUdfClassLoader(String functionSignature) {
public static UdfClassCache getUdfClassLoader(String functionSignature) {
return udfLoadedClasses.get(functionSignature);
}

public static synchronized void cacheClassLoader(String functionSignature, ClassLoader classLoader,
public static synchronized void cacheClassLoader(String functionSignature, UdfClassCache classCache,
long expirationTime) {
LOG.info("cacheClassLoader for: " + functionSignature);
udfLoadedClasses.put(functionSignature, classLoader, expirationTime * 60 * 1000L);
LOG.info("Cache UDF for: " + functionSignature);
udfLoadedClasses.put(functionSignature, classCache, expirationTime * 60 * 1000L);
}

public synchronized void cleanUdfClassLoader(String functionSignature) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.common.jni.utils;

import com.esotericsoftware.reflectasm.MethodAccess;

import java.lang.reflect.Method;

/**
* This class is used for caching the class of UDF.
*/
public class UdfClassCache {
public Class<?> udfClass;
// the index of evaluate() method in the class
public MethodAccess methodAccess;
public int evaluateIndex;
// the method of evaluate() in udf
public Method method;
// the method of prepare() in udf
public Method prepareMethod;
// the argument and return's JavaUdfDataType of evaluate() method.
public JavaUdfDataType[] argTypes;
public JavaUdfDataType retType;
// the class type of the arguments in evaluate() method
public Class[] argClass;
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.classloader.ScannerLoader;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.UdfClassCache;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;

import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;

Expand Down Expand Up @@ -140,26 +143,99 @@ private Method findPrepareMethod(Method[] methods) {
return null; // Method not found
}

public ClassLoader getClassLoader(String jarPath, String signature, long expirationTime)
throws MalformedURLException, FileNotFoundException {
ClassLoader loader = null;
if (jarPath == null) {
// for test
loader = ClassLoader.getSystemClassLoader();
} else {
if (isStaticLoad) {
loader = ScannerLoader.getUdfClassLoader(signature);
}
if (loader == null) {
public UdfClassCache getClassCache(String className, String jarPath, String signature, long expirationTime,
Type funcRetType, Type... parameterTypes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I check the failed two test case, not give jarPath in create function.
so pass jarPath = null to BE,
and maybe should give the path to custom_lib?

throws MalformedURLException, FileNotFoundException, ClassNotFoundException, InternalException,
UdfRuntimeException {
UdfClassCache cache = null;
if (isStaticLoad) {
cache = ScannerLoader.getUdfClassLoader(signature);
}
if (cache == null) {
ClassLoader loader;
if (Strings.isNullOrEmpty(jarPath)) {
// if jarPath is empty, which means the UDF jar is located in custom_lib
// and already be loaded when BE start.
// so here we use system class loader to load UDF class.
loader = ClassLoader.getSystemClassLoader();
} else {
ClassLoader parent = getClass().getClassLoader();
classLoader = UdfUtils.getClassLoader(jarPath, parent);
loader = classLoader;
if (isStaticLoad) {
ScannerLoader.cacheClassLoader(signature, loader, expirationTime);
}
cache = new UdfClassCache();
cache.udfClass = Class.forName(className, true, loader);
cache.methodAccess = MethodAccess.get(cache.udfClass);
checkAndCacheUdfClass(className, cache, funcRetType, parameterTypes);
if (isStaticLoad) {
ScannerLoader.cacheClassLoader(signature, cache, expirationTime);
}
}
return cache;
}

private void checkAndCacheUdfClass(String className, UdfClassCache cache, Type funcRetType, Type... parameterTypes)
throws InternalException, UdfRuntimeException {
ArrayList<String> signatures = Lists.newArrayList();
Class<?> c = cache.udfClass;
Method[] methods = c.getMethods();
Method prepareMethod = findPrepareMethod(methods);
if (prepareMethod != null) {
cache.prepareMethod = prepareMethod;
}
for (Method m : methods) {
// By convention, the udf must contain the function "evaluate"
if (!m.getName().equals(UDF_FUNCTION_NAME)) {
continue;
}
signatures.add(m.toGenericString());
cache.argClass = m.getParameterTypes();

// Try to match the arguments
if (cache.argClass.length != parameterTypes.length) {
continue;
}
cache.method = m;
cache.evaluateIndex = cache.methodAccess.getIndex(UDF_FUNCTION_NAME, cache.argClass);
Pair<Boolean, JavaUdfDataType> returnType;
if (cache.argClass.length == 0 && parameterTypes.length == 0) {
// Special case where the UDF doesn't take any input args
returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
if (!returnType.first) {
continue;
} else {
cache.retType = returnType.second;
}
cache.argTypes = new JavaUdfDataType[0];
return;
}
returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
if (!returnType.first) {
continue;
} else {
cache.retType = returnType.second;
}
Type keyType = cache.retType.getKeyType();
Type valueType = cache.retType.getValueType();
Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, cache.argClass, false);
if (!inputType.first) {
continue;
} else {
cache.argTypes = inputType.second;
}
cache.retType.setKeyType(keyType);
cache.retType.setValueType(valueType);
return;
}
return loader;
StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct signature: ")
.append(className)
.append(".evaluate(")
.append(Joiner.on(", ").join(parameterTypes))
.append(")\n")
.append("UDF contains: \n ")
.append(Joiner.on("\n ").join(signatures));
throw new UdfRuntimeException(sb.toString());
}

// Preallocate the input objects that will be passed to the underlying UDF.
Expand All @@ -168,7 +244,6 @@ public ClassLoader getClassLoader(String jarPath, String signature, long expirat
protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType,
Type... parameterTypes) throws UdfRuntimeException {
String className = request.fn.scalar_fn.symbol;
ArrayList<String> signatures = Lists.newArrayList();
try {
if (LOG.isDebugEnabled()) {
LOG.debug("Loading UDF '" + className + "' from " + jarPath);
Expand All @@ -178,66 +253,21 @@ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type fun
if (request.getFn().isSetExpirationTime()) {
expirationTime = request.getFn().getExpirationTime();
}
ClassLoader loader = getClassLoader(jarPath, request.getFn().getSignature(), expirationTime);
Class<?> c = Class.forName(className, true, loader);
methodAccess = MethodAccess.get(c);
Constructor<?> ctor = c.getConstructor();
UdfClassCache cache = getClassCache(className, jarPath, request.getFn().getSignature(), expirationTime,
funcRetType, parameterTypes);
methodAccess = cache.methodAccess;
Constructor<?> ctor = cache.udfClass.getConstructor();
udf = ctor.newInstance();
Method[] methods = c.getMethods();
Method prepareMethod = findPrepareMethod(methods);
Method prepareMethod = cache.prepareMethod;
if (prepareMethod != null) {
prepareMethod.invoke(udf);
}
for (Method m : methods) {
// By convention, the udf must contain the function "evaluate"
if (!m.getName().equals(UDF_FUNCTION_NAME)) {
continue;
}
signatures.add(m.toGenericString());
argClass = m.getParameterTypes();

// Try to match the arguments
if (argClass.length != parameterTypes.length) {
continue;
}
method = m;
evaluateIndex = methodAccess.getIndex(UDF_FUNCTION_NAME, argClass);
Pair<Boolean, JavaUdfDataType> returnType;
if (argClass.length == 0 && parameterTypes.length == 0) {
// Special case where the UDF doesn't take any input args
returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
if (!returnType.first) {
continue;
} else {
retType = returnType.second;
}
argTypes = new JavaUdfDataType[0];
return;
}
returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
if (!returnType.first) {
continue;
} else {
retType = returnType.second;
}
Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, argClass, false);
if (!inputType.first) {
continue;
} else {
argTypes = inputType.second;
}
return;
}

StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct signature: ")
.append(className)
.append(".evaluate(")
.append(Joiner.on(", ").join(parameterTypes))
.append(")\n")
.append("UDF contains: \n ")
.append(Joiner.on("\n ").join(signatures));
throw new UdfRuntimeException(sb.toString());
argClass = cache.argClass;
method = cache.method;
evaluateIndex = cache.evaluateIndex;
retType = cache.retType;
argTypes = cache.argTypes;
} catch (MalformedURLException e) {
throw new UdfRuntimeException("Unable to load jar.", e);
} catch (SecurityException e) {
Expand All @@ -255,3 +285,5 @@ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type fun
}
}
}


4 changes: 4 additions & 0 deletions fe/fe-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ under the License.
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-core</artifactId>
</dependency>
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo-shaded</artifactId>
</dependency>
</dependencies>
<build>
<finalName>doris-fe-common</finalName>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,34 @@ suite("nereids_test_javaudf_string") {
qt_select """ SELECT java_udf_string_test(string_col, 2, 3) result FROM ${tableName} ORDER BY result; """
qt_select """ SELECT java_udf_string_test('abcdef', 2, 3), java_udf_string_test('abcdefg', 2, 3) result FROM ${tableName} ORDER BY result; """

// test multi thread
Thread thread1 = new Thread(() -> {
try {
for (int ii = 0; ii < 100; ii++) {
sql """ SELECT java_udf_string_test(varchar_col, 2, 3) result FROM ${tableName} ORDER BY result; """
}
} catch (Exception e) {
log.info(e.getMessage())
Assert.fail();
}
})

Thread thread2 = new Thread(() -> {
try {
for (int ii = 0; ii < 100; ii++) {
sql """ SELECT java_udf_string_test(string_col, 2, 3) result FROM ${tableName} ORDER BY result; """
}
} catch (Exception e) {
log.info(e.getMessage())
Assert.fail();
}
})
sleep(1000L)
thread1.start()
thread2.start()

thread1.join()
thread2.join()

} finally {
try_sql("DROP FUNCTION IF EXISTS java_udf_string_test(string, int, int);")
Expand Down