Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions be/src/runtime/user_function_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,10 @@ Status UserFunctionCache::_download_lib(const std::string& url,
return Status::InternalError("fail to open file");
}

std::string real_url = _get_real_url(url);

Md5Digest digest;
HttpClient client;
int64_t file_size = 0;
RETURN_IF_ERROR(client.init(real_url));
RETURN_IF_ERROR(client.init(url));
Status status;
auto download_cb = [&status, &tmp_file, &fp, &digest, &file_size](const void* data,
size_t length) {
Expand All @@ -297,11 +295,10 @@ Status UserFunctionCache::_download_lib(const std::string& url,
digest.digest();
if (!iequal(digest.hex(), entry->checksum)) {
fmt::memory_buffer error_msg;
fmt::format_to(
error_msg,
" The checksum is not equal of {} ({}). The init info of first create entry is:"
"{} But download file check_sum is: {}, file_size is: {}.",
url, real_url, entry->debug_string(), digest.hex(), file_size);
fmt::format_to(error_msg,
" The checksum is not equal of {}. The init info of first create entry is:"
"{} But download file check_sum is: {}, file_size is: {}.",
url, entry->debug_string(), digest.hex(), file_size);
std::string error(fmt::to_string(error_msg));
LOG(WARNING) << error;
return Status::InternalError(error);
Expand All @@ -323,13 +320,6 @@ Status UserFunctionCache::_download_lib(const std::string& url,
return Status::OK();
}

// Resolves a driver location to a URL.
// A value that already contains ":/" is returned untouched; anything else is
// treated as a name relative to config::jdbc_drivers_dir and turned into a
// "file://" URL under that directory.
std::string UserFunctionCache::_get_real_url(const std::string& url) {
    const bool has_scheme = url.find(":/") != std::string::npos;
    return has_scheme ? url : "file://" + config::jdbc_drivers_dir + "/" + url;
}

std::string UserFunctionCache::_get_file_name_from_url(const std::string& url) const {
std::string file_name;
size_t last_slash_pos = url.find_last_of('/');
Expand Down
1 change: 0 additions & 1 deletion be/src/runtime/user_function_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class UserFunctionCache {
const std::string& file_name);
void _destroy_cache_entry(std::shared_ptr<UserFunctionCacheEntry> entry);

std::string _get_real_url(const std::string& url);
std::string _get_file_name_from_url(const std::string& url) const;
std::vector<std::string> _split_string_by_checksum(const std::string& file);

Expand Down
29 changes: 11 additions & 18 deletions be/src/vec/exec/vjdbc_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,7 @@ Status JdbcConnector::open(RuntimeState* state, bool read) {
// Add a scoped cleanup jni reference object. This cleans up local refs made below.
JniLocalFrame jni_frame;
{
std::string local_location;
std::hash<std::string> hash_str;
auto* function_cache = UserFunctionCache::instance();
if (_conn_param.resource_name.empty()) {
// for jdbcExternalTable, _conn_param.resource_name == ""
// so, we use _conn_param.driver_path as key of jarpath
SCOPED_RAW_TIMER(&_jdbc_statistic._load_jar_timer);
RETURN_IF_ERROR(function_cache->get_jarpath(
std::abs((int64_t)hash_str(_conn_param.driver_path)), _conn_param.driver_path,
_conn_param.driver_checksum, &local_location));
} else {
SCOPED_RAW_TIMER(&_jdbc_statistic._load_jar_timer);
RETURN_IF_ERROR(function_cache->get_jarpath(
std::abs((int64_t)hash_str(_conn_param.resource_name)), _conn_param.driver_path,
_conn_param.driver_checksum, &local_location));
}
VLOG_QUERY << "driver local path = " << local_location;
std::string driver_path = _get_real_url(_conn_param.driver_path);

TJdbcExecutorCtorParams ctor_params;
ctor_params.__set_statement(_sql_str);
Expand All @@ -144,7 +128,8 @@ Status JdbcConnector::open(RuntimeState* state, bool read) {
ctor_params.__set_jdbc_user(_conn_param.user);
ctor_params.__set_jdbc_password(_conn_param.passwd);
ctor_params.__set_jdbc_driver_class(_conn_param.driver_class);
ctor_params.__set_driver_path(local_location);
ctor_params.__set_driver_path(driver_path);
ctor_params.__set_jdbc_driver_checksum(_conn_param.driver_checksum);
if (state == nullptr) {
ctor_params.__set_batch_size(read ? 1 : 0);
} else {
Expand Down Expand Up @@ -601,4 +586,12 @@ jobject JdbcConnector::_get_java_table_type(JNIEnv* env, TOdbcTableType::type ta
env->CallStaticObjectMethod(enumClass, findByValueMethod, static_cast<jint>(tableType));
return javaEnumObj;
}

// Maps the configured driver path to the URL handed to the Java executor.
// Paths that already contain ":/" are passed through as-is; any other value is
// resolved against config::jdbc_drivers_dir and prefixed with "file://".
std::string JdbcConnector::_get_real_url(const std::string& url) {
    if (url.find(":/") != std::string::npos) {
        return url;
    }
    std::string resolved = "file://";
    resolved += config::jdbc_drivers_dir;
    resolved += "/";
    resolved += url;
    return resolved;
}

} // namespace doris::vectorized
2 changes: 2 additions & 0 deletions be/src/vec/exec/vjdbc_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class JdbcConnector : public TableConnector {
int rows);
jobject _get_java_table_type(JNIEnv* env, TOdbcTableType::type tableType);

std::string _get_real_url(const std::string& url);

bool _closed = false;
jclass _executor_factory_clazz;
jclass _executor_clazz;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.jdbc;

import org.apache.doris.common.exception.InternalException;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnType;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorColumn;
Expand All @@ -27,16 +26,25 @@
import org.apache.doris.thrift.TJdbcOperation;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.zaxxer.hikari.HikariDataSource;
import org.apache.commons.codec.binary.Hex;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.semver4j.Semver;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.Date;
Expand All @@ -57,6 +65,7 @@ public abstract class BaseJdbcExecutor implements JdbcExecutor {
private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory();
private HikariDataSource hikariDataSource = null;
private final byte[] hikariDataSourceLock = new byte[0];
private ClassLoader classLoader = null;
private Connection conn = null;
protected JdbcDataSourceConfig config;
protected PreparedStatement preparedStatement = null;
Expand All @@ -68,6 +77,7 @@ public abstract class BaseJdbcExecutor implements JdbcExecutor {
protected int batchSizeNum = 0;
protected int curBlockRows = 0;
protected String jdbcDriverVersion;
private static final Map<URL, ClassLoader> classLoaderMap = Maps.newConcurrentMap();

public BaseJdbcExecutor(byte[] thriftParams) throws Exception {
setJdbcDriverSystemProperties();
Expand All @@ -85,6 +95,7 @@ public BaseJdbcExecutor(byte[] thriftParams) throws Exception {
.setJdbcUrl(request.jdbc_url)
.setJdbcDriverUrl(request.driver_path)
.setJdbcDriverClass(request.jdbc_driver_class)
.setJdbcDriverChecksum(request.jdbc_driver_checksum)
.setBatchSize(request.batch_size)
.setOp(request.op)
.setTableType(request.table_type)
Expand Down Expand Up @@ -298,8 +309,7 @@ private void init(JdbcDataSourceConfig config, String sql) throws JdbcExecutorEx
ClassLoader oldClassLoader = Thread.currentThread().getContextClassLoader();
String hikariDataSourceKey = config.createCacheKey();
try {
ClassLoader parent = getClass().getClassLoader();
ClassLoader classLoader = UdfUtils.getClassLoader(config.getJdbcDriverUrl(), parent);
initializeClassLoader(config);
Thread.currentThread().setContextClassLoader(classLoader);
hikariDataSource = JdbcDataSource.getDataSource().getSource(hikariDataSourceKey);
if (hikariDataSource == null) {
Expand Down Expand Up @@ -357,6 +367,60 @@ private void init(JdbcDataSourceConfig config, String sql) throws JdbcExecutorEx
}
}

/**
 * Selects the class loader used to load the JDBC driver named by {@code config}.
 *
 * <p>If a loader for the driver URL is already present in {@code classLoaderMap} it is
 * reused. Otherwise the driver is read once to compute its checksum, the checksum is
 * compared with the one carried in {@code config}, and a new {@link URLClassLoader}
 * is created and cached.
 *
 * @param config data source configuration providing the driver URL and expected checksum
 * @throws RuntimeException if the driver URL is malformed, the checksum cannot be computed,
 *         or the expected checksum is missing or differs from the computed one
 */
private synchronized void initializeClassLoader(JdbcDataSourceConfig config)
        throws MalformedURLException, FileNotFoundException {
    try {
        URL[] urls = {new URL(config.getJdbcDriverUrl())};
        // NOTE(review): java.net.URL#equals/hashCode may resolve host names; a String or URI
        // key would avoid that, but the map is declared outside this method.
        // NOTE(review): this method locks the instance while classLoaderMap is static, so two
        // executors can still race to create a loader for the same URL — confirm this is acceptable.
        ClassLoader cached = classLoaderMap.get(urls[0]);
        if (cached != null) {
            this.classLoader = cached;
            return;
        }

        String expectedChecksum = config.getJdbcDriverChecksum();
        String actualChecksum = computeObjectChecksum(urls[0].toString(), null);
        // The checksum is an optional field and may be absent; treat that as a mismatch instead of
        // failing with a NullPointerException. Hex digests are compared case-insensitively.
        if (expectedChecksum == null || !expectedChecksum.equalsIgnoreCase(actualChecksum)) {
            throw new RuntimeException("Checksum mismatch for JDBC driver.");
        }

        ClassLoader parent = getClass().getClassLoader();
        this.classLoader = URLClassLoader.newInstance(urls, parent);
        classLoaderMap.put(urls[0], this.classLoader);
    } catch (MalformedURLException e) {
        throw new RuntimeException("Error loading JDBC driver.", e);
    }
}

/**
 * Streams the content behind {@code urlStr} and returns its MD5 digest as a hex string.
 *
 * <p>The stream is opened through {@link #getInputStreamFromUrl} with 10000 ms connect and
 * read timeouts and is closed when this method returns.
 *
 * @param urlStr URL of the object to hash
 * @param encodedAuthInfo value for HTTP Basic authorization, or {@code null} for none
 * @return hex-encoded MD5 digest of the content
 * @throws RuntimeException if the content cannot be read or MD5 is unavailable; the
 *         underlying exception is attached as the cause
 */
public static String computeObjectChecksum(String urlStr, String encodedAuthInfo) {
    try (InputStream inputStream = getInputStreamFromUrl(urlStr, encodedAuthInfo, 10000, 10000)) {
        MessageDigest digest = MessageDigest.getInstance("MD5");
        byte[] buf = new byte[4096];
        int bytesRead;
        while ((bytesRead = inputStream.read(buf)) != -1) {
            digest.update(buf, 0, bytesRead);
        }
        return Hex.encodeHexString(digest.digest());
    } catch (IOException | NoSuchAlgorithmException e) {
        // Keep the original exception as the cause so the stack trace is not lost.
        throw new RuntimeException("Compute driver checksum from url: " + urlStr
                + " encountered an error: " + e.getMessage(), e);
    }
}

/**
 * Opens a connection to {@code urlStr} and returns its input stream.
 *
 * @param urlStr URL to open
 * @param encodedAuthInfo if non-null, sent as {@code Authorization: Basic <value>}
 * @param connectTimeoutMs connect timeout applied to the connection, in milliseconds
 * @param readTimeoutMs read timeout applied to the connection, in milliseconds
 * @return the connection's input stream; the caller is responsible for closing it
 * @throws IOException wrapping any failure raised while building or opening the connection
 */
public static InputStream getInputStreamFromUrl(String urlStr, String encodedAuthInfo, int connectTimeoutMs,
        int readTimeoutMs) throws IOException {
    try {
        final URLConnection connection = new URL(urlStr).openConnection();

        final boolean sendCredentials = encodedAuthInfo != null;
        if (sendCredentials) {
            connection.setRequestProperty("Authorization", "Basic " + encodedAuthInfo);
        }

        connection.setConnectTimeout(connectTimeoutMs);
        connection.setReadTimeout(readTimeoutMs);

        final InputStream stream = connection.getInputStream();
        return stream;
    } catch (Exception failure) {
        throw new IOException("Failed to open URL connection: " + urlStr, failure);
    }
}

/**
 * Sets the connection test query on the given pool to {@code SELECT 1}.
 * Declared protected so subclasses can supply a different query.
 *
 * @param ds the Hikari data source to configure
 */
protected void setValidationQuery(HikariDataSource ds) {
    final String probeSql = "SELECT 1";
    ds.setConnectionTestQuery(probeSql);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class JdbcDataSourceConfig {
private String jdbcPassword;
private String jdbcDriverUrl;
private String jdbcDriverClass;
private String jdbcDriverChecksum;
private int batchSize;
private TJdbcOperation op;
private TOdbcTableType tableType;
Expand Down Expand Up @@ -96,6 +97,15 @@ public JdbcDataSourceConfig setJdbcDriverClass(String jdbcDriverClass) {
return this;
}

/**
 * @return the checksum stored for the JDBC driver, or {@code null} if none was set
 */
public String getJdbcDriverChecksum() {
    return this.jdbcDriverChecksum;
}

/**
 * Stores the checksum for the JDBC driver.
 *
 * @param jdbcDriverChecksum checksum value to store
 * @return this config, to allow chained setter calls
 */
public JdbcDataSourceConfig setJdbcDriverChecksum(String jdbcDriverChecksum) {
    final JdbcDataSourceConfig self = this;
    self.jdbcDriverChecksum = jdbcDriverChecksum;
    return self;
}

public int getBatchSize() {
return batchSize;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ public static String computeObjectChecksum(String driverPath) throws DdlExceptio
}

public static String getFullDriverUrl(String driverUrl) throws IllegalArgumentException {
if (!(driverUrl.startsWith("file://") || driverUrl.startsWith("http://")
|| driverUrl.startsWith("https://") || driverUrl.matches("^[^:/]+\\.jar$"))) {
throw new IllegalArgumentException("Invalid driver URL format. Supported formats are: "
+ "file://xxx.jar, http://xxx.jar, https://xxx.jar, or xxx.jar (without prefix).");
}

try {
URI uri = new URI(driverUrl);
String schema = uri.getScheme();
Expand Down Expand Up @@ -481,4 +487,3 @@ public static void checkConnectionPoolProperties(int minSize, int maxSize, int m
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;

import java.util.Map;

Expand Down Expand Up @@ -216,4 +217,54 @@ public void testJdbcDriverPtah() {
});
Assert.assertEquals("Driver URL does not match any allowed paths: file:///postgresql-42.5.0.jar", exception.getMessage());
}

/**
 * Accepted driver URL formats: file://, http:// and https:// URLs are returned unchanged,
 * and a bare jar file name is expanded to a URL starting with file://.
 */
@Test
public void testValidDriverUrls() {
    // URLs that already carry a supported scheme must come back unchanged.
    String[] passThroughUrls = {
        "file://path/to/driver.jar",
        "http://example.com/driver.jar",
        "https://example.com/driver.jar",
    };
    for (String candidate : passThroughUrls) {
        Assertions.assertDoesNotThrow(() -> {
            Assert.assertEquals(candidate, JdbcResource.getFullDriverUrl(candidate));
        });
    }

    // A bare jar name is resolved to a file:// URL.
    Assertions.assertDoesNotThrow(() -> {
        String resolved = JdbcResource.getFullDriverUrl("driver.jar");
        Assert.assertTrue(resolved.startsWith("file://"));
    });
}

/**
 * Rejected driver URL formats: each of these inputs must make
 * {@code JdbcResource.getFullDriverUrl} throw {@link IllegalArgumentException}.
 */
@Test
public void testInvalidDriverUrls() {
    String[] rejectedUrls = {
        "/mnt/path/to/driver.jar",      // absolute path without a scheme
        "ftp://example.com/driver.jar", // unsupported scheme
        "",                             // empty input
        "example.com/driver",           // no scheme and no .jar suffix
    };
    for (String candidate : rejectedUrls) {
        Assert.assertThrows(IllegalArgumentException.class,
                () -> JdbcResource.getFullDriverUrl(candidate));
    }
}
}
1 change: 1 addition & 0 deletions gensrc/thrift/Types.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ struct TJdbcExecutorCtorParams {
14: optional i32 connection_pool_cache_clear_time
15: optional bool connection_pool_keep_alive
16: optional i64 catalog_id
17: optional string jdbc_driver_checksum
}

struct TJavaUdfExecutorCtorParams {
Expand Down
Loading