diff --git a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala index dc314ba44a9e..f1aeb1ce077d 100644 --- a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala +++ b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala @@ -16,13 +16,11 @@ */ package org.apache.spark.shuffle -import org.apache.gluten.config.ReservedKeys.{GLUTEN_RSS_SORT_SHUFFLE_WRITER, GLUTEN_SORT_SHUFFLE_WRITER} import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.runtime.Runtimes import org.apache.gluten.utils.ArrowAbiUtil import org.apache.gluten.vectorized._ - import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.internal.config.SHUFFLE_COMPRESS @@ -33,42 +31,45 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.utils.SparkSchemaUtil import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.task.{TaskResource, TaskResources} - import org.apache.arrow.c.ArrowSchema import org.apache.arrow.memory.BufferAllocator import org.apache.celeborn.client.read.CelebornInputStream -import org.apache.gluten.config.GlutenConfig +import org.apache.gluten.config.{GlutenConfig, ReservedKeys} import java.io._ import java.nio.ByteBuffer import java.util.UUID import java.util.concurrent.atomic.AtomicBoolean - import scala.reflect.ClassTag class CelebornColumnarBatchSerializer( schema: StructType, readBatchNumRows: SQLMetric, - numOutputRows: SQLMetric) + numOutputRows: SQLMetric, + isSort: Boolean) extends Serializer with Serializable { /** Creates a new [[SerializerInstance]]. */ override def newInstance(): SerializerInstance = { - new CelebornColumnarBatchSerializerInstance(schema, readBatchNumRows, numOutputRows) + new CelebornColumnarBatchSerializerInstance(schema, readBatchNumRows, numOutputRows, isSort) } } private class CelebornColumnarBatchSerializerInstance( schema: StructType, readBatchNumRows: SQLMetric, - numOutputRows: SQLMetric) + numOutputRows: SQLMetric, + isSort: Boolean) extends SerializerInstance with Logging { private val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "CelebornShuffleReader") + private val shuffleWriterType = + if (isSort) ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER else ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER + private val shuffleReaderHandle = { val allocator: BufferAllocator = ArrowBufferAllocators .contextInstance(classOf[CelebornColumnarBatchSerializerInstance].getSimpleName) @@ -86,8 +87,6 @@ private class CelebornColumnarBatchSerializerInstance( } val compressionCodecBackend = GlutenConfig.get.columnarShuffleCodecBackend.orNull - val shuffleWriterType = GlutenConfig.get.celebornShuffleWriterType - .replace(GLUTEN_SORT_SHUFFLE_WRITER, GLUTEN_RSS_SORT_SHUFFLE_WRITER) val jniWrapper = ShuffleReaderJniWrapper.create(runtime) val batchSize = GlutenConfig.get.maxBatchSize val bufferSize = GlutenConfig.get.columnarShuffleReaderBufferSize diff --git a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala index 115982f48c0f..e5c8837fe7bd 100644 --- a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala +++ b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala @@ -21,7 +21,6 @@ import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers} import org.apache.gluten.runtime.Runtimes import org.apache.gluten.vectorized._ - import org.apache.spark._ import org.apache.spark.internal.config.{SHUFFLE_SORT_INIT_BUFFER_SIZE, SHUFFLE_SORT_USE_RADIXSORT} import org.apache.spark.memory.SparkMemoryUtil @@ -29,11 +28,9 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SparkResourceUtil - import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf -import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.config.ReservedKeys +import org.apache.gluten.config.{GlutenConfig, ReservedKeys} import java.io.IOException @@ -51,7 +48,7 @@ class VeloxCelebornColumnarShuffleWriter[K, V]( celebornConf, client, writeMetrics) { - private val isSort = !ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER.equals(shuffleWriterType) + private val isSort = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]].isSort private val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "CelebornShuffleWriter") @@ -60,6 +57,9 @@ class VeloxCelebornColumnarShuffleWriter[K, V]( private var splitResult: GlutenSplitResult = _ + override val shuffleWriterType: String = + if (isSort) ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER else ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER + private def availableOffHeapPerTask(): Long = { val perTask = SparkMemoryUtil.getCurrentAvailableOffHeapMemory / SparkResourceUtil.getTaskSlots(conf) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index ba56df411eea..00c8266273ab 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -549,9 +549,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { lazy val isCelebornSortBasedShuffle = conf.isUseCelebornShuffleManager && conf.celebornShuffleWriterType == ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER partitioning != SinglePartition && - (partitioning.numPartitions >= GlutenConfig.get.columnarShuffleSortPartitionsThreshold || + ((partitioning.numPartitions >= GlutenConfig.get.columnarShuffleSortPartitionsThreshold || output.size >= GlutenConfig.get.columnarShuffleSortColumnsThreshold) || - isCelebornSortBasedShuffle + isCelebornSortBasedShuffle) } /** @@ -606,8 +606,14 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { if (GlutenConfig.get.isUseCelebornShuffleManager) { val clazz = ClassUtils.getClass("org.apache.spark.shuffle.CelebornColumnarBatchSerializer") val constructor = - clazz.getConstructor(classOf[StructType], classOf[SQLMetric], classOf[SQLMetric]) - constructor.newInstance(schema, readBatchNumRows, numOutputRows).asInstanceOf[Serializer] + clazz.getConstructor( + classOf[StructType], + classOf[SQLMetric], + classOf[SQLMetric], + classOf[Boolean]) + constructor + .newInstance(schema, readBatchNumRows, numOutputRows, isSort: java.lang.Boolean) + .asInstanceOf[Serializer] } else { new ColumnarBatchSerializer( schema, diff --git a/gluten-celeborn/src-celeborn/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala b/gluten-celeborn/src-celeborn/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala index 7a514689e381..db2ae8a2d81a 100644 --- a/gluten-celeborn/src-celeborn/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala +++ b/gluten-celeborn/src-celeborn/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala @@ -19,7 +19,6 @@ package org.apache.spark.shuffle import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.config.ReservedKeys import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.SHUFFLE_COMPRESS @@ -70,9 +69,7 @@ abstract class CelebornColumnarShuffleWriter[K, V]( protected val clientPushSortMemoryThreshold: Long = celebornConf.clientPushSortMemoryThreshold protected val shuffleWriterType: String = - celebornConf.shuffleWriterMode.name - .toLowerCase(Locale.ROOT) - .replace(ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER, ReservedKeys.GLUTEN_RSS_SORT_SHUFFLE_WRITER) + celebornConf.shuffleWriterMode.name.toLowerCase(Locale.ROOT) protected val celebornPartitionPusher = new CelebornPartitionPusher( shuffleId,