diff --git a/extensions-core/kafka-extraction-namespace/pom.xml b/extensions-core/kafka-extraction-namespace/pom.xml index aa4f563f55c1..a3c4a2d05d3a 100644 --- a/extensions-core/kafka-extraction-namespace/pom.xml +++ b/extensions-core/kafka-extraction-namespace/pom.xml @@ -44,6 +44,7 @@ org.apache.druid.extensions druid-lookups-cached-global ${project.parent.version} + provided org.apache.druid diff --git a/extensions-core/kafka-extraction-namespace/src/main/resources/druid-extension-dependencies.json b/extensions-core/kafka-extraction-namespace/src/main/resources/druid-extension-dependencies.json new file mode 100644 index 000000000000..8b9d1666509a --- /dev/null +++ b/extensions-core/kafka-extraction-namespace/src/main/resources/druid-extension-dependencies.json @@ -0,0 +1,3 @@ +{ + "dependsOnDruidExtensions": ["druid-lookups-cached-global"] +} \ No newline at end of file diff --git a/processing/src/main/java/org/apache/druid/guice/DruidExtensionDependencies.java b/processing/src/main/java/org/apache/druid/guice/DruidExtensionDependencies.java new file mode 100644 index 000000000000..87b587849c83 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/guice/DruidExtensionDependencies.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.guice; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nonnull; + +import java.util.ArrayList; +import java.util.List; + +public class DruidExtensionDependencies +{ + + @JsonProperty("dependsOnDruidExtensions") + private List dependsOnDruidExtensions; + + public DruidExtensionDependencies() + { + this.dependsOnDruidExtensions = new ArrayList<>(); + } + + public DruidExtensionDependencies(@Nonnull final List dependsOnDruidExtensions) + { + this.dependsOnDruidExtensions = dependsOnDruidExtensions; + } + + public List getDependsOnDruidExtensions() + { + return dependsOnDruidExtensions; + } +} diff --git a/processing/src/main/java/org/apache/druid/guice/ExtensionFirstClassLoader.java b/processing/src/main/java/org/apache/druid/guice/ExtensionFirstClassLoader.java index 1a2944b2bdb0..4824333af015 100644 --- a/processing/src/main/java/org/apache/druid/guice/ExtensionFirstClassLoader.java +++ b/processing/src/main/java/org/apache/druid/guice/ExtensionFirstClassLoader.java @@ -24,7 +24,6 @@ import java.io.IOException; import java.net.URL; -import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Enumeration; import java.util.List; @@ -32,13 +31,13 @@ /** * The ClassLoader that gets used when druid.extensions.useExtensionClassloaderFirst = true. */ -public class ExtensionFirstClassLoader extends URLClassLoader +public class ExtensionFirstClassLoader extends StandardURLClassLoader { private final ClassLoader druidLoader; - public ExtensionFirstClassLoader(final URL[] urls, final ClassLoader druidLoader) + public ExtensionFirstClassLoader(final URL[] urls, final ClassLoader druidLoader, final List extensionDependencyClassLoaders) { - super(urls, null); + super(urls, null, extensionDependencyClassLoaders); this.druidLoader = Preconditions.checkNotNull(druidLoader, "druidLoader"); } @@ -60,8 +59,13 @@ protected Class loadClass(final String name, final boolean resolve) throws Cl clazz = findClass(name); } catch (ClassNotFoundException e) { - // Try the Druid classloader. Will throw ClassNotFoundException if the class can't be loaded. - return druidLoader.loadClass(name); + try { + clazz = loadClassFromExtensionDependencies(name); + } + catch (ClassNotFoundException e2) { + // Try the Druid classloader. Will throw ClassNotFoundException if the class can't be loaded. + clazz = druidLoader.loadClass(name); + } } } @@ -76,13 +80,18 @@ protected Class loadClass(final String name, final boolean resolve) throws Cl @Override public URL getResource(final String name) { - final URL resourceFromExtension = super.getResource(name); + URL resourceFromExtension = super.getResource(name); if (resourceFromExtension != null) { return resourceFromExtension; - } else { - return druidLoader.getResource(name); } + + resourceFromExtension = getResourceFromExtensionsDependencies(name); + if (resourceFromExtension != null) { + return resourceFromExtension; + } + + return druidLoader.getResource(name); } @Override @@ -90,6 +99,7 @@ public Enumeration getResources(final String name) throws IOException { final List urls = new ArrayList<>(); Iterators.addAll(urls, Iterators.forEnumeration(super.getResources(name))); + addExtensionResources(name, urls); Iterators.addAll(urls, Iterators.forEnumeration(druidLoader.getResources(name))); return Iterators.asEnumeration(urls.iterator()); } diff --git a/processing/src/main/java/org/apache/druid/guice/ExtensionsLoader.java b/processing/src/main/java/org/apache/druid/guice/ExtensionsLoader.java index 0bdbddfa5a3f..f450f4e8b10f 100644 --- a/processing/src/main/java/org/apache/druid/guice/ExtensionsLoader.java +++ b/processing/src/main/java/org/apache/druid/guice/ExtensionsLoader.java @@ -19,13 +19,19 @@ package org.apache.druid.guice; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.inject.Injector; import org.apache.commons.io.FileUtils; import org.apache.druid.initialization.DruidModule; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.RE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.logger.Logger; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import javax.inject.Inject; @@ -34,18 +40,21 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; -import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.ServiceLoader; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; import java.util.stream.Collectors; /** @@ -59,19 +68,28 @@ public class ExtensionsLoader { private static final Logger log = new Logger(ExtensionsLoader.class); - + public static final String DRUID_EXTENSION_DEPENDENCIES_JSON = "druid-extension-dependencies.json"; private final ExtensionsConfig extensionsConfig; - private final ConcurrentHashMap, URLClassLoader> loaders = new ConcurrentHashMap<>(); + private final ObjectMapper objectMapper; + + @GuardedBy("this") + private final HashMap, StandardURLClassLoader> loaders = new HashMap<>(); /** * Map of loaded extensions, keyed by class (or interface). */ - private final ConcurrentHashMap, Collection> extensions = new ConcurrentHashMap<>(); + @GuardedBy("this") + private final HashMap, Collection> extensions = new HashMap<>(); + + @GuardedBy("this") + @MonotonicNonNull + private File[] extensionFilesToLoad; @Inject - public ExtensionsLoader(ExtensionsConfig config) + public ExtensionsLoader(ExtensionsConfig config, ObjectMapper objectMapper) { this.extensionsConfig = config; + this.objectMapper = objectMapper; } public static ExtensionsLoader instance(Injector injector) @@ -92,12 +110,14 @@ public ExtensionsConfig config() */ public Collection getLoadedImplementations(Class clazz) { - @SuppressWarnings("unchecked") - Collection retVal = (Collection) extensions.get(clazz); - if (retVal == null) { - return Collections.emptySet(); + synchronized (this) { + @SuppressWarnings("unchecked") + Collection retVal = (Collection) extensions.get(clazz); + if (retVal == null) { + return Collections.emptySet(); + } + return retVal; } - return retVal; } /** @@ -109,9 +129,11 @@ public Collection getLoadedModules() } @VisibleForTesting - public Map, URLClassLoader> getLoadersMap() + public Map, StandardURLClassLoader> getLoadersMap() { - return loaders; + synchronized (this) { + return loaders; + } } /** @@ -135,12 +157,14 @@ public Collection getFromExtensions(Class serviceClass) // In practice, it appears the only place this matters is with DruidModule: // initialization gets the list of extensions, and two REST API calls later // ask for the same list. - Collection modules = extensions.computeIfAbsent( - serviceClass, - serviceC -> new ServiceLoadingFromExtensions<>(serviceC).implsToLoad - ); - //noinspection unchecked - return (Collection) modules; + synchronized (this) { + Collection modules = extensions.computeIfAbsent( + serviceClass, + serviceC -> new ServiceLoadingFromExtensions<>(serviceC).implsToLoad + ); + //noinspection unchecked + return (Collection) modules; + } } public Collection getModules() @@ -159,7 +183,7 @@ public Collection getModules() * * @return an array of druid extension files that will be loaded by druid process */ - public File[] getExtensionFilesToLoad() + public void initializeExtensionFilesToLoad() { final File rootExtensionsDir = new File(extensionsConfig.getDirectory()); if (rootExtensionsDir.exists() && !rootExtensionsDir.isDirectory()) { @@ -187,25 +211,98 @@ public File[] getExtensionFilesToLoad() extensionsToLoad[i++] = extensionDir; } } - return extensionsToLoad == null ? new File[]{} : extensionsToLoad; + synchronized (this) { + extensionFilesToLoad = extensionsToLoad == null ? new File[]{} : extensionsToLoad; + } + } + + public File[] getExtensionFilesToLoad() + { + synchronized (this) { + if (extensionFilesToLoad == null) { + initializeExtensionFilesToLoad(); + } + return extensionFilesToLoad; + } } /** * @param extension The File instance of the extension we want to load * - * @return a URLClassLoader that loads all the jars on which the extension is dependent + * @return a StandardURLClassLoader that loads all the jars on which the extension is dependent */ - public URLClassLoader getClassLoaderForExtension(File extension, boolean useExtensionClassloaderFirst) + public StandardURLClassLoader getClassLoaderForExtension(File extension, boolean useExtensionClassloaderFirst) + { + return getClassLoaderForExtension(extension, useExtensionClassloaderFirst, new ArrayList<>()); + } + + /** + * @param extension The File instance of the extension we want to load + * @param useExtensionClassloaderFirst Whether to return a StandardURLClassLoader that checks extension classloaders first for classes + * @param extensionDependencyStack If the extension is requested as a dependency of another extension, a list containing the + * dependency stack of the dependent extension (for checking circular dependencies). Otherwise + * this is a empty list. + * @return a StandardURLClassLoader that loads all the jars on which the extension is dependent + */ + public StandardURLClassLoader getClassLoaderForExtension(File extension, boolean useExtensionClassloaderFirst, List extensionDependencyStack) + { + Pair classLoaderKey = Pair.of(extension, useExtensionClassloaderFirst); + synchronized (this) { + StandardURLClassLoader classLoader = loaders.get(classLoaderKey); + if (classLoader == null) { + classLoader = makeClassLoaderWithDruidExtensionDependencies(extension, useExtensionClassloaderFirst, extensionDependencyStack); + loaders.put(classLoaderKey, classLoader); + } + + return classLoader; + } + } + + private StandardURLClassLoader makeClassLoaderWithDruidExtensionDependencies(File extension, boolean useExtensionClassloaderFirst, List extensionDependencyStack) { - return loaders.computeIfAbsent( - Pair.of(extension, useExtensionClassloaderFirst), - k -> makeClassLoaderForExtension(k.lhs, k.rhs) - ); + Optional druidExtensionDependenciesOptional = getDruidExtensionDependencies(extension); + List druidExtensionDependenciesList = druidExtensionDependenciesOptional.isPresent() + ? druidExtensionDependenciesOptional.get().getDependsOnDruidExtensions() + : ImmutableList.of(); + + List extensionDependencyClassLoaders = new ArrayList<>(); + for (String druidExtensionDependencyName : druidExtensionDependenciesList) { + Optional extensionDependencyFileOptional = Arrays.stream(getExtensionFilesToLoad()) + .filter(file -> file.getName().equals(druidExtensionDependencyName)) + .findFirst(); + if (!extensionDependencyFileOptional.isPresent()) { + throw new RE( + StringUtils.format( + "Extension [%s] depends on [%s] which is not a valid extension or not loaded.", + extension.getName(), + druidExtensionDependencyName + ) + ); + } + File extensionDependencyFile = extensionDependencyFileOptional.get(); + if (extensionDependencyStack.contains(extensionDependencyFile.getName())) { + extensionDependencyStack.add(extensionDependencyFile.getName()); + throw new RE( + StringUtils.format( + "Extension [%s] has a circular druid extension dependency. Dependency stack [%s].", + extensionDependencyStack.get(0), + extensionDependencyStack + ) + ); + } + extensionDependencyStack.add(extensionDependencyFile.getName()); + extensionDependencyClassLoaders.add( + getClassLoaderForExtension(extensionDependencyFile, useExtensionClassloaderFirst, extensionDependencyStack) + ); + } + + return makeClassLoaderForExtension(extension, useExtensionClassloaderFirst, extensionDependencyClassLoaders); } - private static URLClassLoader makeClassLoaderForExtension( + private static StandardURLClassLoader makeClassLoaderForExtension( final File extension, - final boolean useExtensionClassloaderFirst + final boolean useExtensionClassloaderFirst, + final List extensionDependencyClassLoaders ) { final Collection jars = FileUtils.listFiles(extension, new String[]{"jar"}, false); @@ -224,9 +321,9 @@ private static URLClassLoader makeClassLoaderForExtension( } if (useExtensionClassloaderFirst) { - return new ExtensionFirstClassLoader(urls, ExtensionsLoader.class.getClassLoader()); + return new ExtensionFirstClassLoader(urls, ExtensionsLoader.class.getClassLoader(), extensionDependencyClassLoaders); } else { - return new URLClassLoader(urls, ExtensionsLoader.class.getClassLoader()); + return new StandardURLClassLoader(urls, ExtensionsLoader.class.getClassLoader(), extensionDependencyClassLoaders); } } @@ -266,6 +363,45 @@ public boolean accept(File dir, String name) } } + private Optional getDruidExtensionDependencies(File extension) + { + final Collection jars = FileUtils.listFiles(extension, new String[]{"jar"}, false); + DruidExtensionDependencies druidExtensionDependencies = null; + String druidExtensionDependenciesJarName = null; + for (File extensionFile : jars) { + try (JarFile jarFile = new JarFile(extensionFile.getPath())) { + Enumeration entries = jarFile.entries(); + + while (entries.hasMoreElements()) { + JarEntry entry = entries.nextElement(); + String entryName = entry.getName(); + if (DRUID_EXTENSION_DEPENDENCIES_JSON.equals(entryName)) { + log.debug("Found extension dependency entry in jar [%s]", extensionFile.getPath()); + if (druidExtensionDependenciesJarName != null) { + throw new RE( + StringUtils.format( + "The extension [%s] has multiple jars [%s] [%s] with dependencies in them. Each jar should be in a separate extension directory.", + extension.getName(), + druidExtensionDependenciesJarName, + jarFile.getName() + ) + ); + } + druidExtensionDependencies = objectMapper.readValue( + jarFile.getInputStream(entry), + DruidExtensionDependencies.class + ); + druidExtensionDependenciesJarName = jarFile.getName(); + } + } + } + catch (IOException e) { + throw new RE(e, "Failed to get dependencies for extension [%s]", extension); + } + } + return druidExtensionDependencies == null ? Optional.empty() : Optional.of(druidExtensionDependencies); + } + private class ServiceLoadingFromExtensions { private final Class serviceClass; @@ -293,17 +429,17 @@ private void addAllFromFileSystem() for (File extension : getExtensionFilesToLoad()) { log.debug("Loading extension [%s] for class [%s]", extension.getName(), serviceClass); try { - final URLClassLoader loader = getClassLoaderForExtension( + final StandardURLClassLoader loader = getClassLoaderForExtension( extension, extensionsConfig.isUseExtensionClassloaderFirst() ); - log.info( - "Loading extension [%s], jars: %s", + "Loading extension [%s], jars: %s. Druid extension dependencies [%s]", extension.getName(), Arrays.stream(loader.getURLs()) .map(u -> new File(u.getPath()).getName()) - .collect(Collectors.joining(", ")) + .collect(Collectors.joining(", ")), + loader.getExtensionDependencyClassLoaders() ); ServiceLoader.load(serviceClass, loader).forEach(impl -> tryAdd(impl, "local file system")); diff --git a/processing/src/main/java/org/apache/druid/guice/StandardURLClassLoader.java b/processing/src/main/java/org/apache/druid/guice/StandardURLClassLoader.java new file mode 100644 index 000000000000..b7b7c0ed741c --- /dev/null +++ b/processing/src/main/java/org/apache/druid/guice/StandardURLClassLoader.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.guice; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterators; + +import java.io.IOException; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.List; + + +/** + * The ClassLoader that gets used when druid.extensions.useExtensionClassloaderFirst = false. + */ +public class StandardURLClassLoader extends URLClassLoader +{ + private final List extensionDependencyClassLoaders; + + public StandardURLClassLoader(final URL[] urls, final ClassLoader druidLoader, final List extensionDependencyClassLoaders) + { + super(urls, druidLoader); + this.extensionDependencyClassLoaders = Preconditions.checkNotNull(extensionDependencyClassLoaders, "extensionDependencyClassLoaders"); + } + + @Override + protected Class loadClass(final String name, final boolean resolve) throws ClassNotFoundException + { + Class clazz; + try { + clazz = super.loadClass(name, resolve); + } + catch (ClassNotFoundException e) { + clazz = loadClassFromExtensionDependencies(name); + } + if (resolve) { + resolveClass(clazz); + } + + return clazz; + } + + @Override + public URL getResource(final String name) + { + URL resource = super.getResource(name); + + if (resource != null) { + return resource; + } + + return getResourceFromExtensionsDependencies(name); + } + + @Override + public Enumeration getResources(final String name) throws IOException + { + final List urls = new ArrayList<>(); + Iterators.addAll(urls, Iterators.forEnumeration(super.getResources(name))); + addExtensionResources(name, urls); + return Iterators.asEnumeration(urls.iterator()); + } + + protected URL getResourceFromExtensionsDependencies(final String name) + { + URL resourceFromExtension = null; + for (ClassLoader classLoader : extensionDependencyClassLoaders) { + resourceFromExtension = classLoader.getResource(name); + if (resourceFromExtension != null) { + break; + } + } + return resourceFromExtension; + } + + protected Class loadClassFromExtensionDependencies(final String name) throws ClassNotFoundException + { + for (ClassLoader classLoader : extensionDependencyClassLoaders) { + try { + return classLoader.loadClass(name); + } + catch (ClassNotFoundException ignored) { + } + } + throw new ClassNotFoundException(); + } + + protected void addExtensionResources(final String name, List urls) throws IOException + { + for (ClassLoader classLoader : extensionDependencyClassLoaders) { + Iterators.addAll(urls, Iterators.forEnumeration(classLoader.getResources(name))); + } + } + + public List getExtensionDependencyClassLoaders() + { + return extensionDependencyClassLoaders; + } +} diff --git a/processing/src/test/java/org/apache/druid/guice/ExtensionsLoaderTest.java b/processing/src/test/java/org/apache/druid/guice/ExtensionsLoaderTest.java index 44b7f06fb3d8..2e9be1741007 100644 --- a/processing/src/test/java/org/apache/druid/guice/ExtensionsLoaderTest.java +++ b/processing/src/test/java/org/apache/druid/guice/ExtensionsLoaderTest.java @@ -21,33 +21,46 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import com.google.inject.Injector; import org.apache.druid.initialization.DruidModule; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.RE; +import org.apache.druid.java.util.common.StringUtils; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; import java.net.URL; -import java.net.URLClassLoader; +import java.nio.charset.Charset; import java.util.Arrays; import java.util.Collection; import java.util.Comparator; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.jar.JarEntry; +import java.util.jar.JarOutputStream; public class ExtensionsLoaderTest { @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder(); + private final ObjectMapper objectMapper = new ObjectMapper(); + private final Map jarFileContents = ImmutableMap.of( + "jar-resource", + "jar-resource-contents".getBytes(Charset.defaultCharset()) + ); + private Injector startupInjector() { return new StartupInjectorBuilder() @@ -76,7 +89,7 @@ public void test04DuplicateClassLoaderExtensions() throws Exception Pair key = Pair.of(extensionDir, true); extnLoader.getLoadersMap() - .put(key, new URLClassLoader(new URL[]{}, ExtensionsLoader.class.getClassLoader())); + .put(key, new StandardURLClassLoader(new URL[]{}, ExtensionsLoader.class.getClassLoader(), ImmutableList.of())); Collection modules = extnLoader.getFromExtensions(DruidModule.class); @@ -90,16 +103,18 @@ public void test04DuplicateClassLoaderExtensions() throws Exception @Test public void test06GetClassLoaderForExtension() throws IOException { - final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig()); + final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper); final File some_extension_dir = temporaryFolder.newFolder(); final File a_jar = new File(some_extension_dir, "a.jar"); final File b_jar = new File(some_extension_dir, "b.jar"); final File c_jar = new File(some_extension_dir, "c.jar"); - a_jar.createNewFile(); - b_jar.createNewFile(); - c_jar.createNewFile(); - final URLClassLoader loader = extnLoader.getClassLoaderForExtension(some_extension_dir, false); + createNewJar(a_jar, jarFileContents); + createNewJar(b_jar, jarFileContents); + createNewJar(c_jar, jarFileContents); + + + final StandardURLClassLoader loader = extnLoader.getClassLoaderForExtension(some_extension_dir, false); final URL[] expectedURLs = new URL[]{a_jar.toURI().toURL(), b_jar.toURI().toURL(), c_jar.toURI().toURL()}; final URL[] actualURLs = loader.getURLs(); Arrays.sort(actualURLs, Comparator.comparing(URL::getPath)); @@ -109,7 +124,7 @@ public void test06GetClassLoaderForExtension() throws IOException @Test public void testGetLoadedModules() { - final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig()); + final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper); Collection modules = extnLoader.getModules(); HashSet moduleSet = new HashSet<>(modules); @@ -134,7 +149,7 @@ public String getDirectory() { return tmpDir.getAbsolutePath(); } - }); + }, objectMapper); Assert.assertArrayEquals( "Non-exist root extensionsDir should return an empty array of File", new File[]{}, @@ -155,7 +170,7 @@ public String getDirectory() return extensionsDir.getAbsolutePath(); } }; - final ExtensionsLoader extnLoader = new ExtensionsLoader(config); + final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper); extnLoader.getExtensionFilesToLoad(); } @@ -172,7 +187,7 @@ public String getDirectory() } }; - final ExtensionsLoader extnLoader = new ExtensionsLoader(config); + final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper); Assert.assertArrayEquals( "Empty root extensionsDir should return an empty array of File", new File[]{}, @@ -196,7 +211,7 @@ public String getDirectory() return extensionsDir.getAbsolutePath(); } }; - final ExtensionsLoader extnLoader = new ExtensionsLoader(config); + final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper); final File mysql_metadata_storage = new File(extensionsDir, "mysql-metadata-storage"); mysql_metadata_storage.mkdir(); @@ -231,7 +246,7 @@ public String getDirectory() return extensionsDir.getAbsolutePath(); } }; - final ExtensionsLoader extnLoader = new ExtensionsLoader(config); + final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper); final File mysql_metadata_storage = new File(extensionsDir, "mysql-metadata-storage"); final File random_extension = new File(extensionsDir, "random-extensions"); @@ -267,7 +282,7 @@ public String getDirectory() }; final File random_extension = new File(extensionsDir, "random-extensions"); random_extension.mkdir(); - final ExtensionsLoader extnLoader = new ExtensionsLoader(config); + final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper); extnLoader.getExtensionFilesToLoad(); } @@ -320,14 +335,139 @@ public void testExtensionsWithSameDirName() throws Exception final File jar1 = new File(extension1, "jar1.jar"); final File jar2 = new File(extension2, "jar2.jar"); - Assert.assertTrue(jar1.createNewFile()); - Assert.assertTrue(jar2.createNewFile()); + createNewJar(jar1, jarFileContents); + createNewJar(jar2, jarFileContents); - final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig()); + final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper); final ClassLoader classLoader1 = extnLoader.getClassLoaderForExtension(extension1, false); final ClassLoader classLoader2 = extnLoader.getClassLoaderForExtension(extension2, false); - Assert.assertArrayEquals(new URL[]{jar1.toURI().toURL()}, ((URLClassLoader) classLoader1).getURLs()); - Assert.assertArrayEquals(new URL[]{jar2.toURI().toURL()}, ((URLClassLoader) classLoader2).getURLs()); + Assert.assertArrayEquals(new URL[]{jar1.toURI().toURL()}, ((StandardURLClassLoader) classLoader1).getURLs()); + Assert.assertArrayEquals(new URL[]{jar2.toURI().toURL()}, ((StandardURLClassLoader) classLoader2).getURLs()); + } + + @Test + public void testGetClassLoaderForExtension_withMissingDependency() throws IOException + { + final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper); + final String druidExtensionDependency = "other-druid-extension"; + final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of(druidExtensionDependency)); + + final File extensionDir = temporaryFolder.newFolder(); + final File extensionJar = new File(extensionDir, "a.jar"); + createNewJar(extensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies))); + + RE exception = Assert.assertThrows(RE.class, () -> { + extnLoader.getClassLoaderForExtension(extensionDir, false); + }); + + Assert.assertEquals( + StringUtils.format("Extension [%s] depends on [%s] which is not a valid extension or not loaded.", extensionDir.getName(), druidExtensionDependency), + exception.getMessage() + ); + } + + @Test + public void testGetClassLoaderForExtension_dependencyLoaded() throws IOException + { + ExtensionsConfig extensionsConfig = new TestExtensionsConfig(temporaryFolder.getRoot().getPath()); + final ExtensionsLoader extnLoader = new ExtensionsLoader(extensionsConfig, objectMapper); + + final File extensionDir = temporaryFolder.newFolder(); + final File extensionJar = new File(extensionDir, "a.jar"); + createNewJar(extensionJar, jarFileContents); + + final File dependentExtensionDir = temporaryFolder.newFolder(); + final File dependentExtensionJar = new File(dependentExtensionDir, "a.jar"); + final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of(extensionDir.getName())); + createNewJar(dependentExtensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies))); + + StandardURLClassLoader classLoader = extnLoader.getClassLoaderForExtension(extensionDir, false); + StandardURLClassLoader dependendentClassLoader = extnLoader.getClassLoaderForExtension(dependentExtensionDir, false); + Assert.assertTrue(dependendentClassLoader.getExtensionDependencyClassLoaders().contains(classLoader)); + Assert.assertEquals(0, classLoader.getExtensionDependencyClassLoaders().size()); + + } + + @Test + public void testGetClassLoaderForExtension_circularDependency() throws IOException + { + ExtensionsConfig extensionsConfig = new TestExtensionsConfig(temporaryFolder.getRoot().getPath()); + final ExtensionsLoader extnLoader = new ExtensionsLoader(extensionsConfig, objectMapper); + + final File extensionDir = temporaryFolder.newFolder(); + final File dependentExtensionDir = temporaryFolder.newFolder(); + + final File extensionJar = new File(extensionDir, "a.jar"); + final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of(dependentExtensionDir.getName())); + createNewJar(extensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies))); + + final File dependentExtensionJar = new File(dependentExtensionDir, "a.jar"); + final DruidExtensionDependencies druidExtensionDependenciesCircular = new DruidExtensionDependencies(ImmutableList.of(extensionDir.getName())); + createNewJar(dependentExtensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependenciesCircular))); + + RE exception = Assert.assertThrows(RE.class, () -> { + extnLoader.getClassLoaderForExtension(extensionDir, false); + }); + + Assert.assertTrue(exception.getMessage().contains("has a circular druid extension dependency.")); + } + + @Test + public void testGetClassLoaderForExtension_multipleDruidJars() throws IOException + { + ExtensionsConfig extensionsConfig = new TestExtensionsConfig(temporaryFolder.getRoot().getPath()); + final ExtensionsLoader extnLoader = new ExtensionsLoader(extensionsConfig, objectMapper); + + final File extensionDir = temporaryFolder.newFolder(); + + final File extensionJar = new File(extensionDir, "a.jar"); + final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of()); + createNewJar(extensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies))); + + final File extensionJar2 = new File(extensionDir, "b.jar"); + createNewJar(extensionJar2, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies))); + + + RE exception = Assert.assertThrows(RE.class, () -> { + extnLoader.getClassLoaderForExtension(extensionDir, false); + }); + + Assert.assertTrue( + exception.getMessage().contains("Each jar should be in a separate extension directory.") + ); + } + + + + private void createNewJar(File jarFileLocation, Map jarFileContents) throws IOException + { + Assert.assertTrue(jarFileLocation.createNewFile()); + FileOutputStream fos = new FileOutputStream(jarFileLocation.getPath()); + JarOutputStream jarOut = new JarOutputStream(fos); + for (Map.Entry fileNameToContents : jarFileContents.entrySet()) { + JarEntry entry = new JarEntry(fileNameToContents.getKey()); + jarOut.putNextEntry(entry); + jarOut.write(fileNameToContents.getValue()); + jarOut.closeEntry(); + } + jarOut.close(); + fos.close(); + } + + private static class TestExtensionsConfig extends ExtensionsConfig + { + final String directory; + + public TestExtensionsConfig(String directory) + { + this.directory = directory; + } + + @Override + public String getDirectory() + { + return directory; + } } }