diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 86807914c2362..557235852fbdb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3586,7 +3586,8 @@ object SQLConf {
   val CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED =
     buildConf("spark.sql.classic.shuffleDependency.fileCleanup.enabled")
       .doc("When enabled, shuffle files will be cleaned up at the end of classic " +
-        "SQL executions.")
+        "SQL executions. Note that this cleanup may cause stage retries and regenerate " +
+        "shuffle files if the same dataframe reference is executed again.")
       .version("4.1.0")
       .booleanConf
       .createWithDefault(Utils.isTesting)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 9f59bded94fe1..6d27740bcea90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -150,7 +150,8 @@ class QueryExecution(
       // with the rest of processing of the root plan being just outputting command results,
       // for eagerly executed commands we mark this place as beginning of execution.
       tracker.setReadyForExecution()
-      val qe = sparkSession.sessionState.executePlan(p, mode)
+      val qe = new QueryExecution(sparkSession, p, mode = mode,
+        shuffleCleanupMode = shuffleCleanupMode)
       val result = QueryExecution.withInternalError(s"Eagerly executed $name failed.") {
         SQLExecution.withNewExecutionId(qe, Some(name)) {
           qe.executedPlan.executeCollect()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index c5c2f9bb6a6f6..1cab0f8d35af5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -30,6 +30,8 @@ import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PRE
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 import org.apache.spark.sql.internal.SQLConf
@@ -68,6 +70,17 @@ object SQLExecution extends Logging {
     }
   }
 
+  private def extractShuffleIds(plan: SparkPlan): Seq[Int] = {
+    plan match {
+      case ae: AdaptiveSparkPlanExec =>
+        ae.context.shuffleIds.asScala.keys.toSeq
+      case nonAdaptivePlan =>
+        nonAdaptivePlan.collect {
+          case exec: ShuffleExchangeLike => exec.shuffleId
+        }
+    }
+  }
+
   /**
    * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that
    * we can connect them with an execution.
@@ -177,13 +190,12 @@ object SQLExecution extends Logging {
         if (queryExecution.shuffleCleanupMode != DoNotCleanup
             && isExecutedPlanAvailable) {
           val shuffleIds = queryExecution.executedPlan match {
-            case ae: AdaptiveSparkPlanExec =>
-              ae.context.shuffleIds.asScala.keys
-            case nonAdaptivePlan =>
-              nonAdaptivePlan.collect {
-                case exec: ShuffleExchangeLike =>
-                  exec.shuffleId
-              }
+            case command: V2CommandExec =>
+              command.children.flatMap(extractShuffleIds)
+            case dataWritingCommand: DataWritingCommandExec =>
+              extractShuffleIds(dataWritingCommand.child)
+            case plan =>
+              extractShuffleIds(plan)
           }
           shuffleIds.foreach { shuffleId =>
             queryExecution.shuffleCleanupMode match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index a4e4a407255c0..0d0a0f2f31007 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -20,7 +20,7 @@ import scala.collection.mutable
 import scala.io.Source
 import scala.util.Try
 
-import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, FastOperator}
+import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, FastOperator, SaveMode}
 import org.apache.spark.sql.catalyst.{QueryPlanningTracker, QueryPlanningTrackerCallback, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{CurrentNamespace, UnresolvedFunction, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions.{Alias, UnsafeRow}
@@ -327,6 +327,41 @@ class QueryExecutionSuite extends SharedSparkSession {
     }
   }
 
+  test("SPARK-53413: Cleanup shuffle dependencies for commands") {
+    Seq(true, false).foreach { adaptiveEnabled => {
+      withSQLConf((SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, adaptiveEnabled.toString),
+        (SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, true.toString)) {
+        val plan = spark.range(100).repartition(10).logicalPlan
+        val df = Dataset.ofRows(spark, plan)
+        df.write.format("noop").mode(SaveMode.Overwrite).save()
+
+        val blockManager = spark.sparkContext.env.blockManager
+        assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
+        assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
+      }
+    }
+    }
+  }
+
+  test("SPARK-53413: Cleanup shuffle dependencies for DataWritingCommandExec") {
+    withTempDir { dir =>
+      Seq(true, false).foreach { adaptiveEnabled => {
+        withSQLConf((SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, adaptiveEnabled.toString),
+          (SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, true.toString)) {
+          val plan = spark.range(100).repartition(10).logicalPlan
+          val df = Dataset.ofRows(spark, plan)
+          // V1 API write
+          df.write.format("csv").mode(SaveMode.Overwrite).save(dir.getCanonicalPath)
+
+          val blockManager = spark.sparkContext.env.blockManager
+          assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
+          assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
+        }
+      }
+      }
+    }
+  }
+
   test("SPARK-47764: Cleanup shuffle dependencies - DoNotCleanup mode") {
     Seq(true, false).foreach { adaptiveEnabled => {
       withSQLConf((SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, adaptiveEnabled.toString)) {
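
For context only (not part of the patch): a minimal, self-contained sketch of the user-facing behavior the new tests assert. It enables the cleanup flag documented in the SQLConf hunk and runs a command-style write that involves a shuffle; the local master, application name, and use of the "noop" sink are illustrative assumptions borrowed from the test code above, not a prescribed usage.

// Sketch only: assumes a local Spark build containing this change (the config is
// versioned 4.1.0 and defaults to Utils.isTesting, so it must be enabled explicitly).
import org.apache.spark.sql.SparkSession

object ShuffleCleanupSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("shuffle-cleanup-sketch")
      .master("local[2]")
      // Flag added/documented in the SQLConf hunk above.
      .config("spark.sql.classic.shuffleDependency.fileCleanup.enabled", "true")
      .getOrCreate()

    // repartition(10) introduces a shuffle; writing to the "noop" sink turns the query
    // into a command execution, which is the case this patch starts covering
    // (V2CommandExec / DataWritingCommandExec in SQLExecution).
    spark.range(100).repartition(10)
      .write.format("noop").mode("overwrite").save()

    // Once the SQL execution ends, the shuffle files for this query are expected to be
    // removed; the suite above checks this via the driver's DiskBlockManager.
    spark.stop()
  }
}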