Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
*/
package org.apache.spark.shuffle

import org.apache.gluten.config.ReservedKeys.{GLUTEN_RSS_SORT_SHUFFLE_WRITER, GLUTEN_SORT_SHUFFLE_WRITER}
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.config.ReservedKeys.{GLUTEN_RSS_SORT_SHUFFLE_WRITER, GLUTEN_SORT_SHUFFLE_WRITER}
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.utils.ArrowAbiUtil
Expand All @@ -37,7 +38,6 @@ import org.apache.spark.task.{TaskResource, TaskResources}
import org.apache.arrow.c.ArrowSchema
import org.apache.arrow.memory.BufferAllocator
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.gluten.config.GlutenConfig

import java.io._
import java.nio.ByteBuffer
Expand Down Expand Up @@ -120,12 +120,8 @@ private class CelebornColumnarBatchSerializerInstance(
private class TaskDeserializationStream(in: InputStream)
extends DeserializationStream
with TaskResource {
private val byteIn: JniByteInputStream = JniByteInputStreams.create(in)
private val wrappedOut: ColumnarBatchOutIterator = new ColumnarBatchOutIterator(
runtime,
ShuffleReaderJniWrapper
.create(runtime)
.readStream(shuffleReaderHandle, byteIn))
private var byteIn: JniByteInputStream = _
private var wrappedOut: ColumnarBatchOutIterator = _

private var cb: ColumnarBatch = _

Expand Down Expand Up @@ -191,6 +187,7 @@ private class CelebornColumnarBatchSerializerInstance(

@throws(classOf[EOFException])
override def readValue[T: ClassTag](): T = {
initStream();
if (cb != null) {
cb.close()
cb = null
Expand Down Expand Up @@ -245,13 +242,26 @@ private class CelebornColumnarBatchSerializerInstance(
readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal)
}
numOutputRows += numRowsTotal
wrappedOut.close()
byteIn.close()
if (byteIn != null) {
wrappedOut.close()
byteIn.close()
}
if (cb != null) {
cb.close()
}
}

private def initStream(): Unit = {
if (byteIn == null) {
byteIn = JniByteInputStreams.create(in)
wrappedOut = new ColumnarBatchOutIterator(
runtime,
ShuffleReaderJniWrapper
.create(runtime)
.readStream(shuffleReaderHandle, byteIn))
}
}

override def resourceName(): String = getClass.getName
}

Expand Down
Loading