diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index e6c312abf1698..3938d92ba36f1 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -95,6 +95,8 @@ license: | - In Spark 3.2, `FloatType` is mapped to `FLOAT` in MySQL. Prior to this, it used to be mapped to `REAL`, which is by default a synonym to `DOUBLE PRECISION` in MySQL. + - In Spark 3.2, the query executions triggered by `DataFrameWriter` and `DataFrameWriterV2` are always named `command` when sent to `QueryExecutionListener`. In Spark 3.1 and earlier, the name is one of `save`, `insertInto`, `saveAsTable`, `create`, `append`, `overwrite`, `overwritePartitions`, `replace`. + ## Upgrading from Spark SQL 3.0 to 3.1 - In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c96fa9e4f903f..5b68493ae132a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2._ @@ -311,13 +310,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions) checkPartitioningMatchesV2Table(table) if (mode == SaveMode.Append) { - runCommand(df.sparkSession, "save") { + runCommand(df.sparkSession) { AppendData.byName(relation, df.logicalPlan, finalOptions) } } else { // Truncate the table. TableCapabilityCheck will throw a nice exception if this // isn't supported - runCommand(df.sparkSession, "save") { + runCommand(df.sparkSession) { OverwriteByExpression.byName( relation, df.logicalPlan, Literal(true), finalOptions) } @@ -332,7 +331,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _) - runCommand(df.sparkSession, "save") { + runCommand(df.sparkSession) { CreateTableAsSelect( catalog, ident, @@ -379,7 +378,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val optionsWithPath = getOptionsWithPath(path) // Code path for data source v1. 
- runCommand(df.sparkSession, "save") { + runCommand(df.sparkSession) { DataSource( sparkSession = df.sparkSession, className = source, @@ -475,13 +474,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - runCommand(df.sparkSession, "insertInto") { + runCommand(df.sparkSession) { command } } private def insertInto(tableIdent: TableIdentifier): Unit = { - runCommand(df.sparkSession, "insertInto") { + runCommand(df.sparkSession) { InsertIntoStatement( table = UnresolvedRelation(tableIdent), partitionSpec = Map.empty[String, Option[String]], @@ -631,7 +630,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { external = false) } - runCommand(df.sparkSession, "saveAsTable") { + runCommand(df.sparkSession) { command } } @@ -698,7 +697,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable")( + runCommand(df.sparkSession)( CreateTable(tableDesc, mode, Some(df.logicalPlan))) } @@ -856,10 +855,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Wrap a DataFrameWriter action to track the QueryExecution and time cost, then report to the * user-registered callback functions. */ - private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = { + private def runCommand(session: SparkSession)(command: LogicalPlan): Unit = { val qe = session.sessionState.executePlan(command) - // call `QueryExecution.toRDD` to trigger the execution of commands. - SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd) + // call `QueryExecution.commandExecuted` to trigger the execution of commands. + qe.commandExecuted } private def lookupV2Provider(): Option[TableProvider] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 9a49fc3d74780..7b131058db4da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -107,7 +107,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) } override def create(): Unit = { - runCommand("create") { + runCommand( CreateTableAsSelectStatement( tableName, logicalPlan, @@ -121,8 +121,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) options.toMap, None, ifNotExists = false, - external = false) - } + external = false)) } override def replace(): Unit = { @@ -146,7 +145,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def append(): Unit = { val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap) - runCommand("append")(append) + runCommand(append) } /** @@ -163,7 +162,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) def overwrite(condition: Column): Unit = { val overwrite = OverwriteByExpression.byName( UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap) - runCommand("overwrite")(overwrite) + runCommand(overwrite) } /** @@ -183,21 +182,21 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) def overwritePartitions(): Unit = { val dynamicOverwrite = OverwritePartitionsDynamic.byName( UnresolvedRelation(tableName), logicalPlan, options.toMap) - runCommand("overwritePartitions")(dynamicOverwrite) + 
runCommand(dynamicOverwrite) } /** * Wrap an action to track the QueryExecution and time cost, then report to the user-registered * callback functions. */ - private def runCommand(name: String)(command: LogicalPlan): Unit = { + private def runCommand(command: LogicalPlan): Unit = { val qe = sparkSession.sessionState.executePlan(command) // call `QueryExecution.toRDD` to trigger the execution of commands. - SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd) + SQLExecution.withNewExecutionId(qe, Some("command"))(qe.toRdd) } private def internalReplace(orCreate: Boolean): Unit = { - runCommand("replace") { + runCommand( ReplaceTableAsSelectStatement( tableName, logicalPlan, @@ -210,8 +209,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) None, options.toMap, None, - orCreate = orCreate) - } + orCreate = orCreate)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 89868c55e2965..a0f7bd2d640b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -221,16 +221,7 @@ class Dataset[T] private[sql]( } @transient private[sql] val logicalPlan: LogicalPlan = { - // For various commands (like DDL) and queries with side effects, we force query execution - // to happen right away to let these side effects take place eagerly. - val plan = queryExecution.analyzed match { - case c: Command => - LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect())) - case u @ Union(children, _, _) if children.forall(_.isInstanceOf[Command]) => - LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect())) - case _ => - queryExecution.analyzed - } + val plan = queryExecution.commandExecuted if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long]) dsIds.add(id) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala new file mode 100644 index 0000000000000..c4f4b04a6440b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Physical plan node for holding data from a command. + * + * `commandPhysicalPlan` is just used to display the plan tree for EXPLAIN. + * `rows` may not be serializable and ideally we should not send `rows` to the executors. + * Thus marking them as transient. + */ +case class CommandResultExec( + output: Seq[Attribute], + @transient commandPhysicalPlan: SparkPlan, + @transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override def innerChildren: Seq[QueryPlan[_]] = Seq(commandPhysicalPlan) + + @transient private lazy val unsafeRows: Array[InternalRow] = { + if (rows.isEmpty) { + Array.empty + } else { + val proj = UnsafeProjection.create(output, output) + rows.map(r => proj(r).copy()).toArray + } + } + + @transient private lazy val rdd: RDD[InternalRow] = { + if (rows.isEmpty) { + sqlContext.sparkContext.emptyRDD + } else { + val numSlices = math.min( + unsafeRows.length, sqlContext.sparkSession.leafNodeDefaultParallelism) + sqlContext.sparkContext.parallelize(unsafeRows, numSlices) + } + } + + override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + rdd.map { r => + numOutputRows += 1 + r + } + } + + override protected def stringArgs: Iterator[Any] = { + if (unsafeRows.isEmpty) { + Iterator("", output) + } else { + Iterator(output) + } + } + + override def executeCollect(): Array[InternalRow] = { + longMetric("numOutputRows").add(rows.size) + rows.toArray + } + + override def executeTake(limit: Int): Array[InternalRow] = { + val taken = unsafeRows.take(limit) + longMetric("numOutputRows").add(taken.size) + taken + } + + override def executeTail(limit: Int): Array[InternalRow] = { + val taken: Seq[InternalRow] = unsafeRows.takeRight(limit) + longMetric("numOutputRows").add(taken.size) + taken.toArray + } + + // Input is already UnsafeRows. + override protected val createUnsafeProjection: Boolean = false + + override def inputRDD: RDD[InternalRow] = rdd +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 5e706498b444d..a13abdc9a3df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -44,36 +44,42 @@ object HiveResult { TimeFormatters(dateFormatter, timestampFormatter) } + private def stripRootCommandResult(executedPlan: SparkPlan): SparkPlan = executedPlan match { + case CommandResultExec(_, plan, _) => plan + case other => other + } + /** * Returns the result as a hive compatible sequence of strings. This is used in tests and * `SparkSQLDriver` for CLI applications. 
*/ - def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match { - case ExecutedCommandExec(_: DescribeCommandBase) => - formatDescribeTableOutput(executedPlan.executeCollectPublic()) - case _: DescribeTableExec => - formatDescribeTableOutput(executedPlan.executeCollectPublic()) - // SHOW TABLES in Hive only output table names while our v1 command outputs - // database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => - command.executeCollect().map(_.getString(1)) - // SHOW TABLES in Hive only output table names while our v2 command outputs - // namespace and table name. - case command : ShowTablesExec => - command.executeCollect().map(_.getString(1)) - // SHOW VIEWS in Hive only outputs view names while our v1 command outputs - // namespace, viewName, and isTemporary. - case command @ ExecutedCommandExec(_: ShowViewsCommand) => - command.executeCollect().map(_.getString(1)) - case other => - val timeFormatters = getTimeFormatters - val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq - // We need the types so we can output struct field names - val types = executedPlan.output.map(_.dataType) - // Reformat to match hive tab delimited output. - result.map(_.zip(types).map(e => toHiveString(e, false, timeFormatters))) - .map(_.mkString("\t")) - } + def hiveResultString(executedPlan: SparkPlan): Seq[String] = + stripRootCommandResult(executedPlan) match { + case ExecutedCommandExec(_: DescribeCommandBase) => + formatDescribeTableOutput(executedPlan.executeCollectPublic()) + case _: DescribeTableExec => + formatDescribeTableOutput(executedPlan.executeCollectPublic()) + // SHOW TABLES in Hive only output table names while our v1 command outputs + // database, table name, isTemp. + case ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => + executedPlan.executeCollect().map(_.getString(1)) + // SHOW TABLES in Hive only output table names while our v2 command outputs + // namespace and table name. + case _ : ShowTablesExec => + executedPlan.executeCollect().map(_.getString(1)) + // SHOW VIEWS in Hive only outputs view names while our v1 command outputs + // namespace, viewName, and isTemporary. + case ExecutedCommandExec(_: ShowViewsCommand) => + executedPlan.executeCollect().map(_.getString(1)) + case other => + val timeFormatters = getTimeFormatters + val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq + // We need the types so we can output struct field names + val types = executedPlan.output.map(_.dataType) + // Reformat to match hive tab delimited output. 
+ result.map(_.zip(types).map(e => toHiveString(e, false, timeFormatters))) + .map(_.mkString("\t")) + } private def formatDescribeTableOutput(rows: Array[Row]): Seq[String] = { rows.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 247baeaeb510e..a794a47ecb57f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableU import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} +import org.apache.spark.sql.expressions.CommandResult import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.Utils @@ -53,7 +54,8 @@ import org.apache.spark.util.Utils class QueryExecution( val sparkSession: SparkSession, val logical: LogicalPlan, - val tracker: QueryPlanningTracker = new QueryPlanningTracker) extends Logging { + val tracker: QueryPlanningTracker = new QueryPlanningTracker, + val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) extends Logging { val id: Long = QueryExecution.nextExecutionId @@ -73,23 +75,51 @@ class QueryExecution( sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } + lazy val commandExecuted: LogicalPlan = mode match { + case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands) + case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed) + case CommandExecutionMode.SKIP => analyzed + } + + private def eagerlyExecuteCommands(p: LogicalPlan) = p transformDown { + case c: Command => + val qe = sparkSession.sessionState.executePlan(c, CommandExecutionMode.NON_ROOT) + val result = + SQLExecution.withNewExecutionId(qe, Some("command"))(qe.executedPlan.executeCollect()) + CommandResult( + qe.analyzed.output, + qe.commandExecuted, + qe.executedPlan, + result) + case other => other + } + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone()) + sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone()) } - lazy val optimizedPlan: LogicalPlan = executePhase(QueryPlanningTracker.OPTIMIZATION) { - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. 
- val plan = sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker) - // We do not want optimized plans to be re-analyzed as literals that have been constant folded - // and such can cause issues during analysis. While `clone` should maintain the `analyzed` state - // of the LogicalPlan, we set the plan as analyzed here as well out of paranoia. - plan.setAnalyzed() - plan + private def assertCommandExecuted(): Unit = commandExecuted + + lazy val optimizedPlan: LogicalPlan = { + // We need to materialize the commandExecuted here because optimizedPlan is also tracked under + // the optimizing phase + assertCommandExecuted() + executePhase(QueryPlanningTracker.OPTIMIZATION) { + // clone the plan to avoid sharing the plan instance between different stages like analyzing, + // optimizing and planning. + val plan = + sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker) + // We do not want optimized plans to be re-analyzed as literals that have been constant + // folded and such can cause issues during analysis. While `clone` should maintain the + // `analyzed` state of the LogicalPlan, we set the plan as analyzed here as well out of + // paranoia. + plan.setAnalyzed() + plan + } } private def assertOptimized(): Unit = optimizedPlan @@ -333,6 +363,19 @@ class QueryExecution( } } +/** + * SPARK-35378: Commands should be executed eagerly so that something like `sql("INSERT ...")` + * can trigger the table insertion immediately without a `.collect()`. To avoid endless recursion + * we should use `NON_ROOT` when recursively executing commands. Note that we can't execute + * a query plan with leaf command nodes, because many commands return `GenericInternalRow` + * and can't be put in a query plan directly, otherwise the query engine may cast + * `GenericInternalRow` to `UnsafeRow` and fail. When running EXPLAIN, or commands nested inside + * another command, we should use `SKIP` to avoid eagerly triggering the command execution. 
+ */ +object CommandExecutionMode extends Enumeration { + val SKIP, NON_ROOT, ALL = Value +} + object QueryExecution { private val _nextExecutionId = new AtomicLong(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 45d4c7d7e08f2..858df124f78df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NU import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan +import org.apache.spark.sql.expressions.CommandResult import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -697,6 +698,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data, _) => LocalTableScanExec(output, data) :: Nil + case CommandResult(output, _, plan, data) => CommandResultExec(output, plan, data) :: Nil case logical.LocalLimit(IntegerLiteral(limit), child) => execution.LocalLimitExec(limit, planLater(child)) :: Nil case logical.GlobalLimit(IntegerLiteral(limit), child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 85bc98d194fee..9e1dbaf5e99ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -934,6 +934,10 @@ case class CollapseCodegenStages( // Do not make LogicalTableScanExec the root of WholeStageCodegen // to support the fast driver-local collect/take paths. plan + case plan: CommandResultExec => + // Do not make CommandResultExec the root of WholeStageCodegen + // to support the fast driver-local collect/take paths. + plan case plan: CodegenSupport if supportCodegen(plan) => // The whole-stage-codegen framework is row-based. If a plan supports columnar execution, // it can't support whole-stage-codegen at the same time. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index 46612d905554d..2bcfa1b108f1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -45,6 +45,7 @@ case class InsertAdaptiveSparkPlan( private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match { case _ if !conf.adaptiveExecutionEnabled => plan case _: ExecutedCommandExec => plan + case _: CommandResultExec => plan case c: DataWritingCommandExec => c.copy(child = apply(c.child)) case c: V2CommandExec => c.withNewChildren(c.children.map(apply)) case _ if shouldApplyAQE(plan, isSubquery) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 7f4f816d328da..42ac51f8fb9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.catalyst.trees.LeafLike import org.apache.spark.sql.connector.ExternalCommandRunner -import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.IncrementalExecution import org.apache.spark.sql.types._ @@ -163,7 +163,8 @@ case class ExplainCommand( // Run through the optimizer to generate the physical plan. 
override def run(sparkSession: SparkSession): Seq[Row] = try { - val outputString = sparkSession.sessionState.executePlan(logicalPlan).explainString(mode) + val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP) + .explainString(mode) Seq(Row(outputString)) } catch { case NonFatal(cause) => ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index d69d2b0f80973..3caf850bfb07f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{CommandExecutionMode, SparkPlan} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -195,7 +195,7 @@ case class CreateDataSourceTableAsSelectCommand( sessionState.executePlan(RepairTableCommand( table.identifier, enableAddPartitions = true, - enableDropPartitions = false)).toRdd + enableDropPartitions = false), CommandExecutionMode.SKIP).toRdd case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/CommandResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/CommandResult.scala new file mode 100644 index 0000000000000..23f5c5f1982fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/CommandResult.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils +import org.apache.spark.sql.execution.SparkPlan + +/** + * Logical plan node for holding data from a command. + * + * `commandLogicalPlan` and `commandPhysicalPlan` are just used to display the plan tree + * for EXPLAIN. + * `rows` may not be serializable and ideally we should not send `rows` to the executors. + * Thus marking them as transient. 
+ */ +case class CommandResult( + output: Seq[Attribute], + @transient commandLogicalPlan: LogicalPlan, + @transient commandPhysicalPlan: SparkPlan, + @transient rows: Seq[InternalRow]) extends LeafNode { + override def innerChildren: Seq[QueryPlan[_]] = Seq(commandLogicalPlan) + + override def computeStats(): Statistics = + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * rows.length) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 28805d611ceb7..8289819644a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.command.CommandCheck @@ -310,9 +310,9 @@ abstract class BaseSessionStateBuilder( /** * Create a query execution object. */ - protected def createQueryExecution: LogicalPlan => QueryExecution = { plan => - new QueryExecution(session, plan) - } + protected def createQueryExecution: + (LogicalPlan, CommandExecutionMode.Value) => QueryExecution = + (plan, mode) => new QueryExecution(session, plan, mode = mode) /** * Interface to start and stop streaming queries. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 79fbca6338905..cdf764a7317dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -76,7 +76,7 @@ private[sql] class SessionState( val streamingQueryManagerBuilder: () => StreamingQueryManager, val listenerManager: ExecutionListenerManager, resourceLoaderBuilder: () => SessionResourceLoader, - createQueryExecution: LogicalPlan => QueryExecution, + createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution, createClone: (SparkSession, SessionState) => SessionState, val columnarRules: Seq[ColumnarRule], val queryStagePrepRules: Seq[Rule[SparkPlan]]) { @@ -119,7 +119,10 @@ private[sql] class SessionState( // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ - def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan) + def executePlan( + plan: LogicalPlan, + mode: CommandExecutionMode.Value = CommandExecutionMode.ALL): QueryExecution = + createQueryExecution(plan, mode) } private[sql] object SessionState { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index 51d734279414a..be0dae2563d88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -178,7 +178,7 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { inputData.write.format(format).save(path.getCanonicalPath) sparkContext.listenerBus.waitUntilEmpty() assert(commands.length == 1) - assert(commands.head._1 == "save") + assert(commands.head._1 == "command") assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] .fileFormat.isInstanceOf[ParquetFileFormat]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index db4a9c153c0ff..945a35a968de3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder} import org.apache.spark.sql.connector.expressions.LogicalExpressions._ -import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan} +import org.apache.spark.sql.execution.{CommandResultExec, QueryExecution, SortExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike @@ -778,7 +778,8 @@ class WriteDistributionAndOrderingSuite sparkContext.listenerBus.waitUntilEmpty() - executedPlan 
match { + assert(executedPlan.isInstanceOf[CommandResultExec]) + executedPlan.asInstanceOf[CommandResultExec].commandPhysicalPlan match { case w: V2TableWriteExec => stripAQEPlan(w.query) case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index eb93a5eca6560..e67d52712f043 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -19,9 +19,12 @@ package org.apache.spark.sql.execution import scala.io.Source import org.apache.spark.sql.{AnalysisException, FastOperator} +import org.apache.spark.sql.catalyst.analysis.UnresolvedNamespace import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project, ShowTables, SubqueryAlias} import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.execution.command.{ExecutedCommandExec, ShowTablesCommand} +import org.apache.spark.sql.expressions.CommandResult import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -236,4 +239,30 @@ class QueryExecutionSuite extends SharedSparkSession { assert(df.queryExecution.optimizedPlan.toString.startsWith("Relation default.spark_34129[")) } } + + test("SPARK-35378: Eagerly execute non-root Command") { + def qe(logicalPlan: LogicalPlan): QueryExecution = new QueryExecution(spark, logicalPlan) + + val showTables = ShowTables(UnresolvedNamespace(Seq.empty[String]), None) + val showTablesQe = qe(showTables) + assert(showTablesQe.commandExecuted.isInstanceOf[CommandResult]) + assert(showTablesQe.executedPlan.isInstanceOf[CommandResultExec]) + val showTablesResultExec = showTablesQe.executedPlan.asInstanceOf[CommandResultExec] + assert(showTablesResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec]) + assert(showTablesResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec] + .cmd.isInstanceOf[ShowTablesCommand]) + + val project = Project(showTables.output, SubqueryAlias("s", showTables)) + val projectQe = qe(project) + assert(projectQe.commandExecuted.isInstanceOf[Project]) + assert(projectQe.commandExecuted.children.length == 1) + assert(projectQe.commandExecuted.children(0).isInstanceOf[SubqueryAlias]) + assert(projectQe.commandExecuted.children(0).children.length == 1) + assert(projectQe.commandExecuted.children(0).children(0).isInstanceOf[CommandResult]) + assert(projectQe.executedPlan.isInstanceOf[CommandResultExec]) + val cmdResultExec = projectQe.executedPlan.asInstanceOf[CommandResultExec] + assert(cmdResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec]) + assert(cmdResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec] + .cmd.isInstanceOf[ShowTablesCommand]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 90ac62c39f48a..d35e2e2fcdeb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -27,7 +27,7 @@ import 
org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.{LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec @@ -1035,8 +1035,11 @@ class AdaptiveQueryExecSuite SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { withTable("t1") { val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[DataWritingCommandExec]) - assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + assert(plan.isInstanceOf[CommandResultExec]) + val commandResultExec = plan.asInstanceOf[CommandResultExec] + assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec]) + assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] + .child.isInstanceOf[AdaptiveSparkPlanExec]) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 179f6171464bb..922e7b89dc01c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -25,7 +25,7 @@ import scala.util.Random import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec @@ -791,9 +791,10 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SPARK-34567: Add metrics for CTAS operator") { withTable("t") { val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a") + assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec]) + val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec] val dataWritingCommandExec = - df.queryExecution.executedPlan.asInstanceOf[DataWritingCommandExec] - dataWritingCommandExec.executeCollect() + commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] val createTableAsSelect = dataWritingCommandExec.cmd assert(createTableAsSelect.metrics.contains("numFiles")) assert(createTableAsSelect.metrics("numFiles").value == 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index b3d29df1b29bc..f6a6e27fca1d2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.util import scala.collection.mutable.ArrayBuffer import org.apache.spark._ -import org.apache.spark.sql.{functions, AnalysisException, Dataset, QueryTest, Row, SparkSession} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.{functions, Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoStatement, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand} +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand} +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StringType @@ -194,7 +193,7 @@ class DataFrameCallbackSuite extends QueryTest spark.range(10).write.format("json").save(path.getCanonicalPath) sparkContext.listenerBus.waitUntilEmpty() assert(commands.length == 1) - assert(commands.head._1 == "save") + assert(commands.head._1 == "command") assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] .fileFormat.isInstanceOf[JsonFileFormat]) @@ -205,10 +204,10 @@ class DataFrameCallbackSuite extends QueryTest spark.range(10).write.insertInto("tab") sparkContext.listenerBus.waitUntilEmpty() assert(commands.length == 3) - assert(commands(2)._1 == "insertInto") - assert(commands(2)._2.isInstanceOf[InsertIntoStatement]) - assert(commands(2)._2.asInstanceOf[InsertIntoStatement].table - .asInstanceOf[UnresolvedRelation].multipartIdentifier == Seq("tab")) + assert(commands(2)._1 == "command") + assert(commands(2)._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + assert(commands(2)._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] + .catalogTable.get.identifier.identifier == "tab") } // exiting withTable adds commands(3) via onSuccess (drops tab) @@ -216,19 +215,21 @@ class DataFrameCallbackSuite extends QueryTest spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab") sparkContext.listenerBus.waitUntilEmpty() assert(commands.length == 5) - assert(commands(4)._1 == "saveAsTable") - assert(commands(4)._2.isInstanceOf[CreateTable]) - assert(commands(4)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p")) + assert(commands(4)._1 == "command") + assert(commands(4)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand]) + assert(commands(4)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand] + .table.partitionColumnNames == Seq("p")) } withTable("tab") { sql("CREATE TABLE tab(i long) using parquet") - val e = intercept[AnalysisException] { - spark.range(10).select($"id", $"id").write.insertInto("tab") + spark.udf.register("illegalUdf", 
udf((value: Long) => value / 0)) + val e = intercept[SparkException] { + spark.range(10).selectExpr("illegalUdf(id)").write.insertInto("tab") } sparkContext.listenerBus.waitUntilEmpty() assert(exceptions.length == 1) - assert(exceptions.head._1 == "insertInto") + assert(exceptions.head._1 == "command") assert(exceptions.head._2 == e) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala index a2de43d737704..2236f236d739d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.execution.CommandResultExec import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils @@ -42,9 +43,10 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> canOptimized.toString) { withTable("t") { val df = sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a") + assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec]) + val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec] val dataWritingCommandExec = - df.queryExecution.executedPlan.asInstanceOf[DataWritingCommandExec] - dataWritingCommandExec.executeCollect() + commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] val createTableAsSelect = dataWritingCommandExec.cmd if (canOptimized) { assert(createTableAsSelect.isInstanceOf[OptimizedCreateHiveTableAsSelectCommand]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala index ec73d7f71887f..3769de07d8a37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ -import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{CommandExecutionMode, QueryExecution, SQLExecution} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf} @@ -584,8 +584,9 @@ private[hive] class TestHiveSparkSession( private[hive] class TestHiveQueryExecution( sparkSession: TestHiveSparkSession, - logicalPlan: LogicalPlan) - extends QueryExecution(sparkSession, logicalPlan) with Logging { + logicalPlan: LogicalPlan, + mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) + extends QueryExecution(sparkSession, logicalPlan, mode = mode) with Logging { def this(sparkSession: TestHiveSparkSession, sql: String) = { this(sparkSession, sparkSession.sessionState.sqlParser.parsePlan(sql)) @@ -661,9 +662,10 @@ private[sql] class TestHiveSessionStateBuilder( override def overrideConfs: Map[String, String] = TestHiveContext.overrideConfs - override def 
createQueryExecution: (LogicalPlan) => QueryExecution = { plan => - new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan) - } + override def createQueryExecution: + (LogicalPlan, CommandExecutionMode.Value) => QueryExecution = + (plan, mode) => + new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan, mode) override protected def newBuilder: NewBuilder = new TestHiveSessionStateBuilder(_, _) }
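
The migration-guide entry and the `DataFrameCallbackSuite` changes above describe what `QueryExecutionListener` implementations now observe. Below is a minimal, self-contained sketch of that listener-visible change, assuming a local `SparkSession`; the object name, application name, output path, and the `Thread.sleep` workaround for the asynchronous listener bus are placeholders for illustration only. On Spark 3.1 the printed name would be `save`; with this change it is `command`.

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object CommandNameListenerExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("CommandNameListenerExample")
      .getOrCreate()

    // Record every name the session reports to registered listeners.
    val names = ArrayBuffer.empty[String]
    spark.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        names += funcName
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        names += funcName
    })

    // A DataFrameWriter action: reported as "save" in Spark 3.1 and earlier,
    // and as the generic "command" after this change.
    spark.range(5).write.mode("overwrite").parquet("/tmp/command-name-example")

    // Listener events are delivered asynchronously; a short sleep keeps the sketch simple.
    Thread.sleep(2000)
    println(names.mkString(", "))
    spark.stop()
  }
}
```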
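For the eager execution path itself (`QueryExecution.commandExecuted`, the new `CommandResult` logical node, and `CommandResultExec` on the physical side), here is a rough sketch of the observable behavior, assuming a local session and throwaway table names; it mirrors what the new `QueryExecutionSuite` test asserts, not a prescribed usage pattern.

```scala
import org.apache.spark.sql.SparkSession

object EagerCommandExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("EagerCommandExample")
      .getOrCreate()

    spark.sql("DROP TABLE IF EXISTS eager_cmd_demo")

    // The CTAS runs inside `sql(...)` itself; no `.collect()` is needed to trigger the write.
    val df = spark.sql("CREATE TABLE eager_cmd_demo USING parquet AS SELECT 1 AS a")

    // The returned Dataset is backed by a CommandResult logical node that holds the
    // already-computed rows, and by CommandResultExec in the physical plan.
    println(df.queryExecution.commandExecuted.getClass.getSimpleName) // CommandResult
    println(df.queryExecution.executedPlan.getClass.getSimpleName)    // CommandResultExec

    // EXPLAIN plans the nested command with CommandExecutionMode.SKIP,
    // so describing a command does not execute it a second time.
    spark.sql("EXPLAIN CREATE TABLE eager_cmd_demo2 USING parquet AS SELECT 1 AS a")
      .show(truncate = false)

    spark.stop()
  }
}
```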