From c800e6863b7c7a843c59b5143112f78274c0517b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 13 Sep 2022 22:06:39 +0900 Subject: [PATCH 1/2] [SPARK-40414][SQL][PYSPARK] More generic type on PythonArrowInput and PythonArrowOutput --- .../execution/python/ArrowPythonRunner.scala | 4 +- .../python/CoGroupedArrowPythonRunner.scala | 2 +- .../execution/python/PythonArrowInput.scala | 50 ++++++++++++++----- .../execution/python/PythonArrowOutput.scala | 22 +++++--- 4 files changed, 56 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 137e2fe93c790..8467feb91d144 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -34,8 +34,8 @@ class ArrowPythonRunner( protected override val timeZoneId: String, protected override val workerConf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) - with PythonArrowInput - with PythonArrowOutput { + with BasicPythonArrowInput + with BasicPythonArrowOutput { override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index e3d8a943d8cf2..2661896ececc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -48,7 +48,7 @@ class CoGroupedArrowPythonRunner( conf: Map[String, String]) extends BasePythonRunner[ (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets) - with PythonArrowOutput { + with BasicPythonArrowOutput { override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 79365080f8cb3..896bc68ac29be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -26,21 +26,30 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils /** * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from - * JVM (an iterator of internal rows) to Python (Arrow). + * JVM (an iterator of internal rows + additional data if required) to Python (Arrow). 
*/ -private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] => +private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => + protected val sqlConf = SQLConf.get + protected val workerConf: Map[String, String] protected val schema: StructType protected val timeZoneId: String + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[IN]): Unit + protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { // Write config for the worker as a number of key -> value pairs of strings stream.writeInt(workerConf.size) @@ -53,7 +62,7 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[Iterator[InternalRow]], + inputIterator: Iterator[IN], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { @@ -74,17 +83,8 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() - while (inputIterator.hasNext) { - val nextBatch = inputIterator.next() - - while (nextBatch.hasNext) { - arrowWriter.write(nextBatch.next()) - } + writeIteratorToArrowStream(root, writer, dataOut, inputIterator) - arrowWriter.finish() - writer.writeBatch() - arrowWriter.reset() - } // end writes footer to the output stream and doesn't clean any resources. // It could throw exception if the output stream is closed, so it should be // in the try block. @@ -107,3 +107,27 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna } } } + +private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] { + self: BasePythonRunner[Iterator[InternalRow], _] => + + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[Iterator[InternalRow]]): Unit = { + val arrowWriter = ArrowWriter.create(root) + + while (inputIterator.hasNext) { + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index d06a0d012990b..66b8297a5865f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -33,12 +33,14 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column /** * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from - * Python (Arrow) to JVM (ColumnarBatch). + * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch). 
*/ -private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] => +private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT + protected def newReaderIterator( stream: DataInputStream, writerThread: WriterThread, @@ -47,7 +49,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc worker: Socket, pid: Option[Int], releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { + context: TaskContext): Iterator[OUT] = { new ReaderIterator( stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { @@ -74,7 +76,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc super.handleEndOfDataSection() } - protected override def read(): ColumnarBatch = { + protected override def read(): OUT = { if (writerThread.exception.isDefined) { throw writerThread.exception.get } @@ -84,7 +86,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc if (batchLoaded) { val batch = new ColumnarBatch(vectors) batch.setNumRows(root.getRowCount) - batch + deserializeColumnarBatch(batch, schema) } else { reader.close(false) allocator.close() @@ -108,7 +110,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc throw handlePythonException() case SpecialLengths.END_OF_DATA_SECTION => handleEndOfDataSection() - null + null.asInstanceOf[OUT] } } } catch handleException @@ -116,3 +118,11 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc } } } + +private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarBatch] { + self: BasePythonRunner[_, ColumnarBatch] => + + protected def deserializeColumnarBatch( + batch: ColumnarBatch, + schema: StructType): ColumnarBatch = batch +} From beecc9e66fab4ffeda2f64381297e502532bc0d8 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 14 Sep 2022 10:30:20 +0900 Subject: [PATCH 2/2] reflect feedbacks --- .../spark/sql/execution/python/PythonArrowInput.scala | 11 ++++------- .../sql/execution/python/PythonArrowOutput.scala | 4 ++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 896bc68ac29be..6168d0f867adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -26,7 +26,6 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -36,8 +35,6 @@ import org.apache.spark.util.Utils * JVM (an iterator of internal rows + additional data if required) to Python (Arrow). 
*/ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => - protected val sqlConf = SQLConf.get - protected val workerConf: Map[String, String] protected val schema: StructType @@ -112,10 +109,10 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In self: BasePythonRunner[Iterator[InternalRow], _] => protected def writeIteratorToArrowStream( - root: VectorSchemaRoot, - writer: ArrowStreamWriter, - dataOut: DataOutputStream, - inputIterator: Iterator[Iterator[InternalRow]]): Unit = { + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[Iterator[InternalRow]]): Unit = { val arrowWriter = ArrowWriter.create(root) while (inputIterator.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 66b8297a5865f..339f114539c28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -123,6 +123,6 @@ private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarB self: BasePythonRunner[_, ColumnarBatch] => protected def deserializeColumnarBatch( - batch: ColumnarBatch, - schema: StructType): ColumnarBatch = batch + batch: ColumnarBatch, + schema: StructType): ColumnarBatch = batch }
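
Illustration (a hypothetical sketch, not part of either patch above): with PythonArrowInput[IN] and PythonArrowOutput[OUT] now parameterized, a runner is no longer tied to Iterator[InternalRow] on the way in and ColumnarBatch on the way out. Existing runners keep the old behavior by mixing in BasicPythonArrowInput / BasicPythonArrowOutput, while a new runner can supply its own types by implementing writeIteratorToArrowStream and deserializeColumnarBatch. The class name KeyedArrowPythonRunner, the (Int, Iterator[InternalRow]) input type, and the (Long, ColumnarBatch) output type below are invented purely for this sketch; how the key would actually be transferred to the Python worker is deliberately left out.

// Hypothetical sketch only: the traits are private[python], so a concrete
// runner like this would have to live in the same package.
package org.apache.spark.sql.execution.python

import java.io.DataOutputStream

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

class KeyedArrowPythonRunner(
    funcs: Seq[ChainedPythonFunctions],
    evalType: Int,
    argOffsets: Array[Array[Int]],
    protected override val schema: StructType,
    protected override val timeZoneId: String,
    protected override val workerConf: Map[String, String])
  extends BasePythonRunner[(Int, Iterator[InternalRow]), (Long, ColumnarBatch)](
    funcs, evalType, argOffsets)
  with PythonArrowInput[(Int, Iterator[InternalRow])]
  with PythonArrowOutput[(Long, ColumnarBatch)] {

  // JVM -> Python: write each group's rows as one Arrow batch, mirroring
  // BasicPythonArrowInput except that the key is dropped here (sending it to
  // the worker is out of scope for this sketch).
  protected override def writeIteratorToArrowStream(
      root: VectorSchemaRoot,
      writer: ArrowStreamWriter,
      dataOut: DataOutputStream,
      inputIterator: Iterator[(Int, Iterator[InternalRow])]): Unit = {
    val arrowWriter = ArrowWriter.create(root)
    while (inputIterator.hasNext) {
      val (_, rows) = inputIterator.next()
      while (rows.hasNext) {
        arrowWriter.write(rows.next())
      }
      arrowWriter.finish()
      writer.writeBatch()
      arrowWriter.reset()
    }
  }

  // Python -> JVM: invoked by PythonArrowOutput.read() for every Arrow batch
  // coming back from the worker; here each batch is paired with its row count.
  protected override def deserializeColumnarBatch(
      batch: ColumnarBatch,
      schema: StructType): (Long, ColumnarBatch) = {
    (batch.numRows().toLong, batch)
  }
}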