From fd1dece7a6ef237b4cd3a2de8b12fe0fb2dff513 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Thu, 3 Apr 2025 17:30:09 +0100 Subject: [PATCH] fixup --- ...VeloxCelebornColumnarBatchSerializer.scala | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala index dc314ba44a9e..fbf1c673036d 100644 --- a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala +++ b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala @@ -16,8 +16,9 @@ */ package org.apache.spark.shuffle -import org.apache.gluten.config.ReservedKeys.{GLUTEN_RSS_SORT_SHUFFLE_WRITER, GLUTEN_SORT_SHUFFLE_WRITER} import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.config.GlutenConfig +import org.apache.gluten.config.ReservedKeys.{GLUTEN_RSS_SORT_SHUFFLE_WRITER, GLUTEN_SORT_SHUFFLE_WRITER} import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.runtime.Runtimes import org.apache.gluten.utils.ArrowAbiUtil @@ -37,7 +38,6 @@ import org.apache.spark.task.{TaskResource, TaskResources} import org.apache.arrow.c.ArrowSchema import org.apache.arrow.memory.BufferAllocator import org.apache.celeborn.client.read.CelebornInputStream -import org.apache.gluten.config.GlutenConfig import java.io._ import java.nio.ByteBuffer @@ -120,12 +120,8 @@ private class CelebornColumnarBatchSerializerInstance( private class TaskDeserializationStream(in: InputStream) extends DeserializationStream with TaskResource { - private val byteIn: JniByteInputStream = JniByteInputStreams.create(in) - private val wrappedOut: ColumnarBatchOutIterator = new ColumnarBatchOutIterator( - runtime, - ShuffleReaderJniWrapper - .create(runtime) - .readStream(shuffleReaderHandle, byteIn)) + private var byteIn: JniByteInputStream = _ + private var wrappedOut: ColumnarBatchOutIterator = _ private var cb: ColumnarBatch = _ @@ -191,6 +187,7 @@ private class CelebornColumnarBatchSerializerInstance( @throws(classOf[EOFException]) override def readValue[T: ClassTag](): T = { + initStream(); if (cb != null) { cb.close() cb = null @@ -245,13 +242,26 @@ private class CelebornColumnarBatchSerializerInstance( readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) } numOutputRows += numRowsTotal - wrappedOut.close() - byteIn.close() + if (byteIn != null) { + wrappedOut.close() + byteIn.close() + } if (cb != null) { cb.close() } } + private def initStream(): Unit = { + if (byteIn == null) { + byteIn = JniByteInputStreams.create(in) + wrappedOut = new ColumnarBatchOutIterator( + runtime, + ShuffleReaderJniWrapper + .create(runtime) + .readStream(shuffleReaderHandle, byteIn)) + } + } + override def resourceName(): String = getClass.getName }