diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index c9b7493fcdc1b..fa24fb7ea8652 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -34,4 +34,7 @@ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader( super.loadClass(name, resolve) } + def loadClass(name: String, b: Array[Byte], off: Int, length: Int): Class[_] = { + super.defineClass(name, b, off, length) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 536276b5cb29f..e52ce130b154a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -188,9 +188,9 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) val input = childCode.value - val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName - val DoublePrefixCmp = classOf[DoublePrefixComparator].getName - val StringPrefixCmp = classOf[StringPrefixComparator].getName + val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getCanonicalName + val DoublePrefixCmp = classOf[DoublePrefixComparator].getCanonicalName + val StringPrefixCmp = classOf[StringPrefixComparator].getCanonicalName val prefixCode = child.child.dataType match { case BooleanType => s"$input ? 1L : 0L" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d5857e060a2c4..4b862a6a05b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,35 +17,29 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import java.io.ByteArrayInputStream -import java.util.{Map => JavaMap} +import java.util.Locale -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.language.existentials -import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} -import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler} -import org.codehaus.janino.util.ClassFile -import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException} -import org.apache.spark.executor.InputMetrics +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.expressions.codegen.compiler._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ -import org.apache.spark.util.{ParentClassLoader, Utils} +import org.apache.spark.util.Utils /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. @@ -133,7 +127,7 @@ class CodegenContext { def addReferenceObj(objName: String, obj: Any, className: String = null): String = { val idx = references.length references += obj - val clsName = Option(className).getOrElse(obj.getClass.getName) + val clsName = Option(className).getOrElse(obj.getClass.getCanonicalName) s"(($clsName) references[$idx] /* $objName */)" } @@ -1170,6 +1164,20 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val genericMutableRowType: String = classOf[GenericInternalRow].getName + // In janino, Scala.Function1 needs this. But, name crashes happen in JDK compilers + // because of type erasure. Probably, it seems this issue is related to a topic below; + // - https://stackoverflow.com/questions/12206181/generic-class-compiles-in-java-6-but-not-java-7 + // I do not look into this issue, so we need to revisit this. + protected lazy val janinoCompatibilityCode = if (CodeGenerator.janinoCompilerEnabled) { + s""" + |public java.lang.Object apply(java.lang.Object row) { + | return apply((InternalRow) row); + |} + """.stripMargin + } else { + "" + } + /** * Generates a class for a given input expression. Called when there is not cached code * already available. @@ -1229,8 +1237,27 @@ object CodeGenerator extends Logging { // bytecode instruction final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + // Make this non-private and non-lazy method for testing + def _compilerImpl(): CompilerBase = { + val compiler = SQLConf.get.javaCompiler + val compilerInstance = compiler match { + case "janino" => JaninoCompiler + case "jdk" => if (JdkCompiler.javaCompiler != null) JdkCompiler else JaninoCompiler + case unknown => throw new IllegalArgumentException(s"Unknown compiler found: $unknown") + } + val compilerName = if (compilerInstance == JaninoCompiler) "Janino" else "JDK" + logInfo(s"$compilerName Java bytecode compiler is used") + compilerInstance + } + + private lazy val compilerImpl = _compilerImpl + + lazy val janinoCompilerEnabled: Boolean = { + compilerImpl == JaninoCompiler + } + /** - * Compile the Java source code into a Java class, using Janino. + * Compile the Java source code into a Java class, using Janino or javac * * @return a pair of a generated class and the max bytecode size of generated functions. */ @@ -1245,113 +1272,10 @@ object CodeGenerator extends Logging { } /** - * Compile the Java source code into a Java class, using Janino. + * Compile the Java source code into a Java class, using Janino or javac */ private[this] def doCompile(code: CodeAndComment): (GeneratedClass, Int) = { - val evaluator = new ClassBodyEvaluator() - - // A special classloader used to wrap the actual parent classloader of - // [[org.codehaus.janino.ClassBodyEvaluator]] (see CodeGenerator.doCompile). This classloader - // does not throw a ClassNotFoundException with a cause set (i.e. exception.getCause returns - // a null). This classloader is needed because janino will throw the exception directly if - // the parent classloader throws a ClassNotFoundException with cause set instead of trying to - // find other possible classes (see org.codehaus.janinoClassLoaderIClassLoader's - // findIClass method). Please also see https://issues.apache.org/jira/browse/SPARK-15622 and - // https://issues.apache.org/jira/browse/SPARK-11636. - val parentClassLoader = new ParentClassLoader(Utils.getContextOrSparkClassLoader) - evaluator.setParentClassLoader(parentClassLoader) - // Cannot be under package codegen, or fail with java.lang.InstantiationException - evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") - evaluator.setDefaultImports( - classOf[Platform].getName, - classOf[InternalRow].getName, - classOf[UnsafeRow].getName, - classOf[UTF8String].getName, - classOf[Decimal].getName, - classOf[CalendarInterval].getName, - classOf[ArrayData].getName, - classOf[UnsafeArrayData].getName, - classOf[MapData].getName, - classOf[UnsafeMapData].getName, - classOf[Expression].getName, - classOf[TaskContext].getName, - classOf[TaskKilledException].getName, - classOf[InputMetrics].getName - ) - evaluator.setExtendedClass(classOf[GeneratedClass]) - - logDebug({ - // Only add extra debugging info to byte code when we are going to print the source code. - evaluator.setDebuggingInformation(true, true, false) - s"\n${CodeFormatter.format(code)}" - }) - - val maxCodeSize = try { - evaluator.cook("generated.java", code.body) - updateAndGetCompilationStats(evaluator) - } catch { - case e: InternalCompilerException => - val msg = s"failed to compile: $e" - logError(msg, e) - val maxLines = SQLConf.get.loggingMaxLinesForCodegen - logInfo(s"\n${CodeFormatter.format(code, maxLines)}") - throw new InternalCompilerException(msg, e) - case e: CompileException => - val msg = s"failed to compile: $e" - logError(msg, e) - val maxLines = SQLConf.get.loggingMaxLinesForCodegen - logInfo(s"\n${CodeFormatter.format(code, maxLines)}") - throw new CompileException(msg, e.getLocation) - } - - (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) - } - - /** - * Returns the max bytecode size of the generated functions by inspecting janino private fields. - * Also, this method updates the metrics information. - */ - private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): Int = { - // First retrieve the generated classes. - val classes = { - val resultField = classOf[SimpleCompiler].getDeclaredField("result") - resultField.setAccessible(true) - val loader = resultField.get(evaluator).asInstanceOf[ByteArrayClassLoader] - val classesField = loader.getClass.getDeclaredField("classes") - classesField.setAccessible(true) - classesField.get(loader).asInstanceOf[JavaMap[String, Array[Byte]]].asScala - } - - // Then walk the classes to get at the method bytecode. - val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute") - val codeAttrField = codeAttr.getDeclaredField("code") - codeAttrField.setAccessible(true) - val codeSizes = classes.flatMap { case (_, classBytes) => - CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) - try { - val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - val stats = cf.methodInfos.asScala.flatMap { method => - method.getAttributes().filter(_.getClass eq codeAttr).map { a => - val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length - CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) - - if (byteCodeSize > DEFAULT_JVM_HUGE_METHOD_LIMIT) { - logInfo("Generated method too long to be JIT compiled: " + - s"${cf.getThisClassName}.${method.getName} is $byteCodeSize bytes") - } - - byteCodeSize - } - } - Some(stats) - } catch { - case NonFatal(e) => - logWarning("Error calculating stats of compiled class.", e) - None - } - }.flatten - - codeSizes.max + compilerImpl.compile(code) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d588e7f081303..943d8bceef381 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -127,8 +127,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP return (InternalRow) mutableRow; } - public java.lang.Object apply(java.lang.Object _i) { - InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $janinoCompatibilityCode + + public InternalRow apply(InternalRow _i) { + InternalRow ${ctx.INPUT_ROW} = _i; $evalSubexpr $allProjections // copy all the results into MutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 39778661d1c48..254eaa05cc2ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -187,8 +187,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initPartition()} } - public java.lang.Object apply(java.lang.Object _i) { - InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $janinoCompatibilityCode + + public InternalRow apply(InternalRow _i) { + InternalRow ${ctx.INPUT_ROW} = _i; $allExpressions return mutableRow; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 0ecd0de8d8203..b9a765d2b4574 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -355,10 +355,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | ${ctx.initPartition()} | } | - | // Scala.Function1 need this - | public java.lang.Object apply(java.lang.Object row) { - | return apply((InternalRow) row); - | } + | $janinoCompatibilityCode | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { | ${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/CompilerBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/CompilerBase.scala new file mode 100644 index 0000000000000..8b1ed5ca2c46b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/CompilerBase.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen.compiler + +import java.io.ByteArrayInputStream + +import scala.collection.JavaConverters._ +import scala.language.existentials +import scala.util.control.NonFatal + +import org.codehaus.janino.util.ClassFile + +import org.apache.spark.{TaskContext, TaskKilledException} +import org.apache.spark.executor.InputMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.CodegenMetrics +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, GeneratedClass} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.Decimal +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.Utils + + +abstract class CompilerBase extends Logging { + protected val className = "org.apache.spark.sql.catalyst.expressions.GeneratedClass" + + protected val importClassNames = Seq( + classOf[Platform].getName, + classOf[InternalRow].getName, + classOf[UnsafeRow].getName, + classOf[UnsafeProjection].getName, + classOf[UTF8String].getName, + classOf[Decimal].getName, + classOf[CalendarInterval].getName, + classOf[ArrayData].getName, + classOf[UnsafeArrayData].getName, + classOf[MapData].getName, + classOf[UnsafeMapData].getName, + classOf[Expression].getName, + classOf[TaskContext].getName, + classOf[TaskKilledException].getName, + classOf[InputMetrics].getName + ) + + protected val extendedClass = classOf[GeneratedClass] + + protected val debugSource = true + protected val debugLines = true + protected val debugVars = false + + protected def prefixLineNumbers(code: String): String = { + if (!debugLines) { + return code + } + val out = new StringBuilder(code.length * 3 / 2) + var i = 1 + for (line <- code.split("\n")) { + val start = out.length + out.append(i) + i += 1 + val numLength = out.length() - start + out.append(":") + for (spaces <- 0 until 7 - numLength) { + out.append(" ") + } + out.append(line) + out.append('\n') + } + out.toString() + } + + def compile(code: CodeAndComment): (GeneratedClass, Int) + + /** + * Returns the max bytecode size of the generated functions by inspecting janino private fields. + * Also, this method updates the metrics information. + */ + protected def updateAndGetBytecodeSize(byteCodes: Iterable[Array[Byte]]): Int = { + // Walk the classes to get at the method bytecode. + val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute") + val codeAttrField = codeAttr.getDeclaredField("code") + codeAttrField.setAccessible(true) + val codeSizes = byteCodes.flatMap { byteCode => + CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(byteCode.size) + try { + val cf = new ClassFile(new ByteArrayInputStream(byteCode)) + val stats = cf.methodInfos.asScala.flatMap { method => + method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + byteCodeSize + } + } + Some(stats) + } catch { + case NonFatal(e) => + logWarning("Error calculating stats of compiled class.", e) + None + } + }.flatten + + codeSizes.max + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/JaninoCompiler.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/JaninoCompiler.scala new file mode 100644 index 0000000000000..3f3b4f24b624a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/JaninoCompiler.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen.compiler + +import java.util.{Map => JavaMap} + +import scala.collection.JavaConverters._ + +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino._ + +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{ParentClassLoader, Utils} + + +object JaninoCompiler extends CompilerBase { + + private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): Int = { + // First retrieve the generated classes. + val classes = { + val resultField = classOf[SimpleCompiler].getDeclaredField("result") + resultField.setAccessible(true) + val loader = resultField.get(evaluator).asInstanceOf[ByteArrayClassLoader] + val classesField = loader.getClass.getDeclaredField("classes") + classesField.setAccessible(true) + classesField.get(loader).asInstanceOf[JavaMap[String, Array[Byte]]].asScala + } + updateAndGetBytecodeSize(classes.values) + } + + override def compile(code: CodeAndComment): (GeneratedClass, Int) = { + val evaluator = new ClassBodyEvaluator() + + // A special classloader used to wrap the actual parent classloader of + // [[org.codehaus.janino.ClassBodyEvaluator]] (see CodeGenerator.doCompile). This classloader + // does not throw a ClassNotFoundException with a cause set (i.e. exception.getCause returns + // a null). This classloader is needed because janino will throw the exception directly if + // the parent classloader throws a ClassNotFoundException with cause set instead of trying to + // find other possible classes (see org.codehaus.janinoClassLoaderIClassLoader's + // findIClass method). Please also see https://issues.apache.org/jira/browse/SPARK-15622 and + // https://issues.apache.org/jira/browse/SPARK-11636. + val parentClassLoader = new ParentClassLoader(Utils.getContextOrSparkClassLoader) + evaluator.setParentClassLoader(parentClassLoader) + // Cannot be under package codegen, or fail with java.lang.InstantiationException + evaluator.setClassName(className) + importClassNames.map(evaluator.setDefaultImports(_)) + evaluator.setExtendedClass(extendedClass) + + logDebug({ + // Only add extra debugging info to byte code when we are going to print the source code. + evaluator.setDebuggingInformation(debugSource, debugLines, debugVars) + s"\n${CodeFormatter.format(code)}" + }) + + try { + evaluator.cook("generated.java", code.body) + } catch { + case e: InternalCompilerException => + val msg = s"failed to compile: $e" + logError(msg, e) + val maxLines = SQLConf.get.loggingMaxLinesForCodegen + logInfo(s"\n${CodeFormatter.format(code, maxLines)}") + throw new InternalCompilerException(msg, e) + case e: CompileException => + val msg = s"failed to compile: $e" + logError(msg, e) + val maxLines = SQLConf.get.loggingMaxLinesForCodegen + logInfo(s"\n${CodeFormatter.format(code, maxLines)}") + throw new CompileException(msg, e.getLocation) + } + + val maxCodeSize = updateAndGetCompilationStats(evaluator) + (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/JdkCompiler.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/JdkCompiler.scala new file mode 100644 index 0000000000000..28fc7d4b66b43 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/compiler/JdkCompiler.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen.compiler + +import java.io._ +import java.net.URI +import java.nio.charset.StandardCharsets +import java.util.{Arrays, Locale} +import javax.tools._ +import javax.tools.JavaFileManager.Location +import javax.tools.JavaFileObject.Kind + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.language.existentials + +import org.codehaus.commons.compiler.{CompileException, Location => CompilerLocation} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, GeneratedClass} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ParentClassLoader + + +class JavaCodeManager(fileManager: JavaFileManager) + extends ForwardingJavaFileManager[JavaFileManager](fileManager) with Logging { + + // Holds a map between class names and `JavaCode`s (it has all the inner classes) + val objects = mutable.Map[String, JavaCode]() + + private val classLoader = new ClassLoader(null) { + // Loads a compile class into a current context class loader + val parentLoader = new ParentClassLoader(Thread.currentThread().getContextClassLoader) + + private def loadAllObjects(): Unit = { + objects.foreach { case (className, javaCode) => + try { + parentLoader.loadClass(className) + } catch { + case _: ClassNotFoundException => + val bytecode = javaCode.getBytecode + parentLoader.loadClass(className, bytecode, 0, bytecode.length) + } + } + } + + override def findClass(name: String): Class[_] = { + try { + parentLoader.loadClass(name) + } catch { + case _: ClassNotFoundException => + loadAllObjects() + parentLoader.loadClass(name) + } + } + } + + override def getClassLoader(location: Location): ClassLoader = { + classLoader + } + + override def getJavaFileForOutput( + location: Location, + className: String, + kind: Kind, + sibling: FileObject): JavaFileObject = sibling match { + case code: JavaCode => + logDebug(s"getJavaFileForOutput called: className=$className sibling=${code.className}") + val javaCode = if (code.className != className) JavaCode(className) else code + objects += className -> javaCode + javaCode + case unknown => + throw new CompileException(s"Unknown source file found: $unknown", null) + } +} + +case class JavaCode(className: String, code: Option[String] = None) + extends SimpleJavaFileObject( + URI.create(s"string:///${className.replace('.', '/')}${Kind.SOURCE.extension}"), + Kind.SOURCE) { + + // Holds compiled bytecode + private val outputStream = new ByteArrayOutputStream() + + def getBytecode: Array[Byte] = outputStream.toByteArray + + override def getCharContent(ignoreEncodingErrors: Boolean): CharSequence = code.getOrElse("") + + override def openReader(ignoreEncodingErrors: Boolean): Reader = { + code.map { c => new StringReader(c) }.getOrElse { + throw new CompileException(s"Failed to open a reader for $className", null) + } + } + + override def openOutputStream(): OutputStream = { + outputStream + } +} + +class JDKDiagnosticListener extends DiagnosticListener[JavaFileObject] { + override def report(diagnostic: Diagnostic[_ <: JavaFileObject]): Unit = { + if (diagnostic.getKind == javax.tools.Diagnostic.Kind.ERROR) { + val message = s"$diagnostic (${diagnostic.getCode()})" + val loc = new CompilerLocation( + diagnostic.getSource.toString, + diagnostic.getLineNumber.toShort, + diagnostic.getColumnNumber.toShort + ) + + // Wrap the exception in a RuntimeException, because "report()" + // does not declare checked exceptions. + throw new RuntimeException(new CompileException(message, loc)) + } + } +} + +object JdkCompiler extends CompilerBase { + val javaCompiler = { + ToolProvider.getSystemJavaCompiler + } + + private val compilerOptions = { + val debugOption = new StringBuilder("-g:") + if (this.debugSource) debugOption.append("source,") + if (this.debugLines) debugOption.append("lines,") + if (this.debugVars) debugOption.append("vars,") + if ("-g".equals(debugOption)) { + debugOption.append("none,") + } + val compilerOption = SQLConf.get.jdkCompilerOptions + if (compilerOption != null) { + debugOption.append(compilerOption).append(",") + } + + Arrays.asList("-classpath", System.getProperty("java.class.path"), debugOption.toString()) + } + + private val listener = new JDKDiagnosticListener() + + private def javaCodeManager() = { + val fm = javaCompiler.getStandardFileManager(listener, Locale.ROOT, StandardCharsets.UTF_8) + new JavaCodeManager(fm) + } + + override def compile(code: CodeAndComment): (GeneratedClass, Int) = { + val clazzName = "GeneratedIterator" + + val importClasses = importClassNames.map(name => s"import $name;").mkString("\n") + + val codeWithImports = + s""" + |$importClasses + | + |public class $clazzName extends ${extendedClass.getName} { + |${code.body} + |} + """.stripMargin + + val javaCode = JavaCode(clazzName, Some(codeWithImports)) + val fileManager = javaCodeManager() + val task = javaCompiler.getTask( + null, fileManager, listener, compilerOptions, null, Arrays.asList(javaCode)) + + logDebug({ + s"\n${prefixLineNumbers(CodeFormatter.format(code))}" + }) + + try { + if (!task.call()) { + throw new CompileException("Compilation failed", null) + } + } catch { + case e: RuntimeException => + // Unwrap the compilation exception wrapped at JDKDiagnosticListener and throw it. + val cause = e.getCause + if (cause != null) { + cause.getCause match { + case _: CompileException => throw cause.getCause.asInstanceOf[CompileException] + case _: IOException => throw cause.getCause.asInstanceOf[IOException] + case _ => + } + } + throw e + case _: Throwable => throw new CompileException("Compilation failed", null) + } + + val byteCodes = fileManager.objects.toMap.values.map(_.getBytecode) + val maxMethodBytecodeSize = updateAndGetBytecodeSize(byteCodes) + + val clazz = fileManager.getClassLoader(null).loadClass(clazzName) + (clazz.newInstance().asInstanceOf[GeneratedClass], maxMethodBytecodeSize) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b24d7486f3454..b363bb8feb7c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1551,7 +1551,7 @@ case class ArraysOverlap(left: Expression, right: Expression) val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) val javaElementClass = CodeGenerator.boxedType(elementType) - val javaSet = classOf[java.util.HashSet[_]].getName + val javaSet = classOf[java.util.HashSet[_]].getCanonicalName val set = ctx.freshName("set") val addToSetFromSmallerCode = nullSafeElementCodegen( smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index da70d7da7351b..bc24dd912792b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1740,6 +1740,10 @@ class SQLConf extends Serializable with Logging { def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + def javaCompiler: String = getConf(StaticSQLConf.CODEGEN_JAVA_COMPILER) + + def jdkCompilerOptions: String = getConf(StaticSQLConf.CODEGEN_JDK_JAVA_COMPILER_OPTION) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index d9c354b165e52..739d5d0b9855d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -82,6 +82,22 @@ object StaticSQLConf { .booleanConf .createWithDefault(false) + val CODEGEN_JAVA_COMPILER = buildStaticConf("spark.sql.codegen.javaCompiler") + .internal() + .doc("Sets the Java bytecode compiler for compiling Java methods for DataFrame or Dataset " + + "program. Acceptable values include: jdk or janino") + .stringConf + .checkValues(Set("jdk", "janino")) + .createWithDefault("jdk") + + val CODEGEN_JDK_JAVA_COMPILER_OPTION = + buildStaticConf("spark.sql.codegen.javaCompiler.jdkOption") + .internal() + .doc("Sets compiler options for JDK Java bytecode compiler. This is ignored when janino is" + + "is selected") + .stringConf + .createWithDefault("") + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug")