diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index ffac330cd6d94..ced66cd0a08f0 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -145,36 +145,10 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
         newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) }
       }
 
-      val signal = new Object
-      val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
-
-      val processPartition = (iter: Iterator[Batch]) => iter.toArray
-      // This callback is executed by the DAGScheduler thread.
-      // After fetching a partition, it inserts the partition into the Map, and then
-      // wakes up the main thread.
-      val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
-        signal.synchronized {
-          partitions(partitionId) = partition
-          signal.notify()
-        }
-        ()
-      }
-
-      spark.sparkContext.runJob(batches, processPartition, resultHandler)
-
-      // The man thread will wait until 0-th partition is available,
-      // then send it to client and wait for next partition.
-      var currentPartitionId = 0
-      while (currentPartitionId < numPartitions) {
-        val partition = signal.synchronized {
-          while (!partitions.contains(currentPartitionId)) {
-            signal.wait()
-          }
-          partitions.remove(currentPartitionId).get
-        }
-
-        partition.foreach { case (bytes, count) =>
+      def writeBatches(arrowBatches: Array[Batch]): Unit = {
+        for (arrowBatch <- arrowBatches) {
+          val (bytes, count) = arrowBatch
           val response = proto.Response.newBuilder().setClientId(clientId)
           val batch = proto.Response.ArrowBatch
             .newBuilder()
@@ -185,9 +159,30 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
           responseObserver.onNext(response.build())
           numSent += 1
         }
+      }
+
+      // Store collection results for worst case of 1 to N-1 partitions
+      val results = new Array[Array[Batch]](numPartitions - 1)
+      var lastIndex = -1 // index of last partition written
 
-        currentPartitionId += 1
+      // Handler to eagerly write partitions in order
+      val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+        // If result is from next partition in order
+        if (partitionId - 1 == lastIndex) {
+          writeBatches(partition)
+          lastIndex += 1
+          // Write stored partitions that come next in order
+          while (lastIndex < results.length && results(lastIndex) != null) {
+            writeBatches(results(lastIndex))
+            results(lastIndex) = null
+            lastIndex += 1
+          }
+        } else {
+          // Store partitions received out of order
+          results(partitionId - 1) = partition
+        }
       }
+      spark.sparkContext.runJob(batches, (iter: Iterator[Batch]) => iter.toArray, resultHandler)
     }
 
     // Make sure at least 1 batch will be sent.
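
Below is a minimal, self-contained sketch of the eager in-order write pattern this diff introduces, reduced to a local SparkSession streaming plain Int rows instead of Arrow batches. Names such as InOrderCollectExample, writePartition, and buffered are illustrative only and do not appear in the PR; the SparkContext.runJob call and the buffer-and-flush logic mirror the resultHandler added above.

import org.apache.spark.sql.SparkSession

object InOrderCollectExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("in-order-partition-streaming")
      .getOrCreate()
    val sc = spark.sparkContext

    // An RDD with several partitions; partitions may finish in any order.
    val rdd = sc.parallelize(1 to 100, numSlices = 8)
    val numPartitions = rdd.getNumPartitions

    // Stand-in for sending a batch to the client; in SparkConnectStreamHandler
    // this is responseObserver.onNext(...).
    def writePartition(partitionId: Int, rows: Array[Int]): Unit =
      println(s"partition $partitionId -> ${rows.length} rows")

    // Buffer for partitions that arrive before their predecessors. Partition 0
    // is never buffered, so N - 1 slots cover the worst case.
    val buffered = new Array[Array[Int]](numPartitions - 1)
    var lastWritten = -1 // id of the last partition written out

    // Invoked on the driver by the scheduler as each partition's result arrives.
    val resultHandler = (partitionId: Int, rows: Array[Int]) => {
      if (partitionId - 1 == lastWritten) {
        // Next partition in order: write it immediately ...
        writePartition(partitionId, rows)
        lastWritten += 1
        // ... then flush any buffered partitions that now come next in order.
        while (lastWritten < buffered.length && buffered(lastWritten) != null) {
          writePartition(lastWritten + 1, buffered(lastWritten))
          buffered(lastWritten) = null
          lastWritten += 1
        }
      } else {
        // Out-of-order partition: stash it until its predecessors are written.
        buffered(partitionId - 1) = rows
      }
    }

    // runJob streams each partition's result to resultHandler as it completes
    // and returns once every partition has been handled.
    sc.runJob(rdd, (iter: Iterator[Int]) => iter.toArray, resultHandler)

    spark.stop()
  }
}

Compared with the removed code, the sketch keeps runJob as the only blocking call: rather than a dedicated wait/notify loop on the main thread, the 0-th partition (and each subsequent in-order partition) is written as soon as it arrives, and only out-of-order results are held in memory.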