@@ -16,13 +16,11 @@
*/
package org.apache.spark.shuffle

import org.apache.gluten.config.ReservedKeys.{GLUTEN_RSS_SORT_SHUFFLE_WRITER, GLUTEN_SORT_SHUFFLE_WRITER}
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.utils.ArrowAbiUtil
import org.apache.gluten.vectorized._

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SHUFFLE_COMPRESS
@@ -33,42 +31,45 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.utils.SparkSchemaUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.{TaskResource, TaskResources}

import org.apache.arrow.c.ArrowSchema
import org.apache.arrow.memory.BufferAllocator
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.config.{GlutenConfig, ReservedKeys}

import java.io._
import java.nio.ByteBuffer
import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean

import scala.reflect.ClassTag

class CelebornColumnarBatchSerializer(
schema: StructType,
readBatchNumRows: SQLMetric,
numOutputRows: SQLMetric)
numOutputRows: SQLMetric,
isSort: Boolean)
extends Serializer
with Serializable {

/** Creates a new [[SerializerInstance]]. */
override def newInstance(): SerializerInstance = {
new CelebornColumnarBatchSerializerInstance(schema, readBatchNumRows, numOutputRows)
new CelebornColumnarBatchSerializerInstance(schema, readBatchNumRows, numOutputRows, isSort)
}
}

private class CelebornColumnarBatchSerializerInstance(
schema: StructType,
readBatchNumRows: SQLMetric,
numOutputRows: SQLMetric)
numOutputRows: SQLMetric,
isSort: Boolean)
extends SerializerInstance
with Logging {

private val runtime =
Runtimes.contextInstance(BackendsApiManager.getBackendName, "CelebornShuffleReader")

private val shuffleWriterType =
if (isSort) ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER else ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER

private val shuffleReaderHandle = {
val allocator: BufferAllocator = ArrowBufferAllocators
.contextInstance(classOf[CelebornColumnarBatchSerializerInstance].getSimpleName)
@@ -86,8 +87,6 @@ private class CelebornColumnarBatchSerializerInstance(
}
val compressionCodecBackend =
GlutenConfig.get.columnarShuffleCodecBackend.orNull
val shuffleWriterType = GlutenConfig.get.celebornShuffleWriterType
.replace(GLUTEN_SORT_SHUFFLE_WRITER, GLUTEN_RSS_SORT_SHUFFLE_WRITER)
val jniWrapper = ShuffleReaderJniWrapper.create(runtime)
val batchSize = GlutenConfig.get.maxBatchSize
val bufferSize = GlutenConfig.get.columnarShuffleReaderBufferSize
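Net effect in this file: the reader-side serializer no longer derives the writer type from celebornShuffleWriterType (nor rewrites the sort value into its RSS variant); it receives the isSort flag through its constructor and picks the matching ReservedKeys constant, so the decode path always agrees with the encode path chosen by the writer. A minimal sketch of that selection, using only names visible in this diff:

import org.apache.gluten.config.ReservedKeys

// Sketch, not the PR's code verbatim: one boolean decides the writer type on both sides.
def writerTypeFor(isSort: Boolean): String =
  if (isSort) ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER
  else ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER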
@@ -21,19 +21,16 @@ import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.vectorized._

import org.apache.spark._
import org.apache.spark.internal.config.{SHUFFLE_SORT_INIT_BUFFER_SIZE, SHUFFLE_SORT_USE_RADIXSORT}
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SparkResourceUtil

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.config.ReservedKeys
import org.apache.gluten.config.{GlutenConfig, ReservedKeys}

import java.io.IOException

@@ -51,7 +48,7 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
celebornConf,
client,
writeMetrics) {
private val isSort = !ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER.equals(shuffleWriterType)
private val isSort = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]].isSort

private val runtime =
Runtimes.contextInstance(BackendsApiManager.getBackendName, "CelebornShuffleWriter")
@@ -60,6 +57,9 @@ class VeloxCelebornColumnarShuffleWriter[K, V](

private var splitResult: GlutenSplitResult = _

override val shuffleWriterType: String =
if (isSort) ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER else ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER

private def availableOffHeapPerTask(): Long = {
val perTask =
SparkMemoryUtil.getCurrentAvailableOffHeapMemory / SparkResourceUtil.getTaskSlots(conf)
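On the write side, isSort is no longer inferred by string-comparing the config-derived shuffleWriterType from the base class; it is read straight from the ColumnarShuffleDependency carried by the Celeborn handle, and shuffleWriterType is overridden to match. A self-contained toy illustration of that base-default/override pattern (stand-in types and values, not Gluten's):

import java.util.Locale

object WriterTypeOverrideDemo {
  // The base class keeps a config-derived default, as CelebornColumnarShuffleWriter does below.
  abstract class BaseWriter(confWriterMode: String) {
    protected val shuffleWriterType: String = confWriterMode.toLowerCase(Locale.ROOT)
    def describe(): String = s"writerType=$shuffleWriterType"
  }
  // The subclass overrides it with the planner's decision, as VeloxCelebornColumnarShuffleWriter does.
  class VeloxLikeWriter(confWriterMode: String, isSort: Boolean) extends BaseWriter(confWriterMode) {
    override protected val shuffleWriterType: String = if (isSort) "sort" else "hash"
  }
  def main(args: Array[String]): Unit =
    println(new VeloxLikeWriter("HASH", isSort = true).describe()) // prints writerType=sort
}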
@@ -549,9 +549,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
lazy val isCelebornSortBasedShuffle = conf.isUseCelebornShuffleManager &&
conf.celebornShuffleWriterType == ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER
partitioning != SinglePartition &&
(partitioning.numPartitions >= GlutenConfig.get.columnarShuffleSortPartitionsThreshold ||
((partitioning.numPartitions >= GlutenConfig.get.columnarShuffleSortPartitionsThreshold ||
output.size >= GlutenConfig.get.columnarShuffleSortColumnsThreshold) ||
isCelebornSortBasedShuffle
isCelebornSortBasedShuffle)
}
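The only change in this hunk is grouping: && binds tighter than || in Scala, so the previous form parsed as (partitioning != SinglePartition && (threshold checks)) || isCelebornSortBasedShuffle, meaning the Celeborn sort-mode condition alone could select sort-based shuffle even for SinglePartition. The added parentheses keep the SinglePartition guard in front of all three conditions. A REPL-sized check of the two groupings, with plain booleans standing in for the predicates:

// a = partitioning != SinglePartition, b/c = the two threshold checks, d = isCelebornSortBasedShuffle
val (a, b, c, d) = (false, true, false, true)
val before = a && (b || c) || d   // parses as (a && (b || c)) || d  => true
val after  = a && ((b || c) || d) // the guard a now covers d as well => false
assert(before && !after)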

/**
@@ -606,8 +614,14 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
if (GlutenConfig.get.isUseCelebornShuffleManager) {
val clazz = ClassUtils.getClass("org.apache.spark.shuffle.CelebornColumnarBatchSerializer")
val constructor =
clazz.getConstructor(classOf[StructType], classOf[SQLMetric], classOf[SQLMetric])
constructor.newInstance(schema, readBatchNumRows, numOutputRows).asInstanceOf[Serializer]
clazz.getConstructor(
classOf[StructType],
classOf[SQLMetric],
classOf[SQLMetric],
classOf[Boolean])
constructor
.newInstance(schema, readBatchNumRows, numOutputRows, isSort: java.lang.Boolean)
.asInstanceOf[Serializer]
} else {
new ColumnarBatchSerializer(
schema,
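Because the serializer lives in the optional Celeborn module, it is constructed reflectively, so the constructor lookup has to name the new parameter type explicitly. A Scala Boolean parameter compiles to the JVM primitive, which is what classOf[Boolean] denotes, while Constructor.newInstance only accepts objects, hence the flag is passed boxed as java.lang.Boolean. A small stand-alone sketch of the same lookup-and-box pattern (Demo is a hypothetical class, not part of the PR):

class Demo(val tag: String, val isSort: Boolean)

object ReflectiveCtorDemo {
  def main(args: Array[String]): Unit = {
    // classOf[Boolean] resolves to the primitive boolean class, matching the Scala parameter.
    val ctor = classOf[Demo].getConstructor(classOf[String], classOf[Boolean])
    // newInstance takes Objects, so the flag travels boxed and is unboxed by reflection.
    val demo = ctor.newInstance("shuffle", java.lang.Boolean.TRUE)
    println(demo.isSort) // prints true
  }
}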
@@ -19,7 +19,6 @@ package org.apache.spark.shuffle
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.config.ReservedKeys
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SHUFFLE_COMPRESS
@@ -70,9 +69,7 @@ abstract class CelebornColumnarShuffleWriter[K, V](
protected val clientPushSortMemoryThreshold: Long = celebornConf.clientPushSortMemoryThreshold

protected val shuffleWriterType: String =
celebornConf.shuffleWriterMode.name
.toLowerCase(Locale.ROOT)
.replace(ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER, ReservedKeys.GLUTEN_RSS_SORT_SHUFFLE_WRITER)
celebornConf.shuffleWriterMode.name.toLowerCase(Locale.ROOT)

protected val celebornPartitionPusher = new CelebornPartitionPusher(
shuffleId,