diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 53b74c7c85594..2891dd4694a94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -21,6 +21,7 @@ import java.util.Collections
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.{AccumulatorV2, LongAccumulator}
 
 /**
@@ -255,5 +257,13 @@ package object debug {
     override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
       consume(ctx, input)
     }
+
+    override def doExecuteBroadcast[T](): Broadcast[T] = {
+      child.executeBroadcast()
+    }
+
+    override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+      child.executeColumnar()
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 8251ff159e05f..e423420c2914a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.debug
 
+import java.io.ByteArrayOutputStream
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -48,4 +50,45 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
     assert(res.forall{ case (subtree, code) =>
       subtree.contains("Range") && code.contains("Object[]")})
   }
+
+  test("SPARK-28537: DebugExec cannot debug broadcast related queries") {
+    val rightDF = spark.range(10)
+    val leftDF = spark.range(10)
+    val joinedDF = leftDF.join(rightDF, leftDF("id") === rightDF("id"))
+
+    val captured = new ByteArrayOutputStream()
+    Console.withOut(captured) {
+      joinedDF.debug()
+    }
+
+    val output = captured.toString()
+    assert(output.contains(
+      """== BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false])) ==
+        |Tuples output: 0
+        | id LongType: {}
+        |== WholeStageCodegen ==
+        |Tuples output: 10
+        | id LongType: {java.lang.Long}
+        |== Range (0, 10, step=1, splits=2) ==
+        |Tuples output: 0
+        | id LongType: {}""".stripMargin))
+  }
+
+  test("SPARK-28537: DebugExec cannot debug columnar related queries") {
+    val df = spark.range(5)
+    df.persist()
+
+    val captured = new ByteArrayOutputStream()
+    Console.withOut(captured) {
+      df.debug()
+    }
+    df.unpersist()
+
+    val output = captured.toString().replaceAll("#\\d+", "#x")
+    assert(output.contains(
+      """== InMemoryTableScan [id#xL] ==
+        |Tuples output: 0
+        | id LongType: {}
+        |""".stripMargin))
+  }
 }
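
For context, a minimal standalone sketch of how user code reaches the two code paths this change fixes. `debug()` wraps every node of the physical plan in `DebugExec`, so a plan containing a `BroadcastExchange` ends up invoking `doExecuteBroadcast` on the wrapper, and a cached table read back through `InMemoryTableScan` ends up invoking `doExecuteColumnar`; before this patch both calls hit `SparkPlan`'s throwing defaults, and with the overrides above they simply delegate to the child. The object name and session setup below are illustrative assumptions, not part of the patch.

```scala
// Illustrative sketch only: the SparkSession setup and object name are
// assumptions; it requires a Spark build that includes this change.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._ // implicit that adds debug()

object Spark28537Example {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-28537-example")
      .getOrCreate()

    // Broadcast path: an equi-join of two small ranges is planned as a
    // broadcast hash join, so the debugged plan contains a BroadcastExchange
    // node and DebugExec.doExecuteBroadcast is invoked.
    val left = spark.range(10)
    val right = spark.range(10)
    left.join(right, left("id") === right("id")).debug()

    // Columnar path: reading a cached Dataset goes through InMemoryTableScan,
    // which produces columnar batches, invoking DebugExec.doExecuteColumnar.
    val cached = spark.range(5)
    cached.persist()
    cached.debug()
    cached.unpersist()

    spark.stop()
  }
}
```

Delegating straight to `child` keeps the wrapper transparent for these execution modes; the per-tuple instrumentation still runs only on the row-based `doExecute` path, which is why the `BroadcastExchange` node reports `Tuples output: 0` in the expected test output above.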