diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala index 7d6309bf93f3..92519eecf77d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala @@ -22,6 +22,7 @@ import org.apache.gluten.execution.VeloxResizeBatchesExec import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec /** * Try to append [[VeloxResizeBatchesExec]] for shuffle input and output to make the batch sizes in @@ -49,7 +50,16 @@ case class AppendBatchResizeForShuffleInputAndOutput() extends Rule[SparkPlan] { if resizeBatchesShuffleOutputEnabled && shuffle.shuffleWriterType.requiresResizingShuffleOutput => VeloxResizeBatchesExec(a, range.min, range.max) - // Since it's transformed in a bottom to up order, so we may first encountered + case a @ AQEShuffleReadExec( + ShuffleQueryStageExec( + _, + ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec), + _), + _) + if resizeBatchesShuffleOutputEnabled && + shuffle.shuffleWriterType.requiresResizingShuffleOutput => + VeloxResizeBatchesExec(a, range.min, range.max) + // Since it's transformed in a bottom to up order, so we may first encounter // ShuffeQueryStageExec, which is transformed to VeloxResizeBatchesExec(ShuffeQueryStageExec), // then we see AQEShuffleReadExec case a @ AQEShuffleReadExec( @@ -61,10 +71,29 @@ case class AppendBatchResizeForShuffleInputAndOutput() extends Rule[SparkPlan] { if resizeBatchesShuffleOutputEnabled && shuffle.shuffleWriterType.requiresResizingShuffleOutput => VeloxResizeBatchesExec(a.copy(child = s), range.min, range.max) + case a @ AQEShuffleReadExec( + VeloxResizeBatchesExec( + s @ ShuffleQueryStageExec( + _, + ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec), + _), + _, + _), + _) + if resizeBatchesShuffleOutputEnabled && + shuffle.shuffleWriterType.requiresResizingShuffleOutput => + VeloxResizeBatchesExec(a.copy(child = s), range.min, range.max) case s @ ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) if resizeBatchesShuffleOutputEnabled && shuffle.shuffleWriterType.requiresResizingShuffleOutput => VeloxResizeBatchesExec(s, range.min, range.max) + case s @ ShuffleQueryStageExec( + _, + ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec), + _) + if resizeBatchesShuffleOutputEnabled && + shuffle.shuffleWriterType.requiresResizingShuffleOutput => + VeloxResizeBatchesExec(s, range.min, range.max) } } }