diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 857c89d7a98f5..e717f4041df10 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -25,6 +25,7 @@ import java.util.Properties
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.CompletionIterator
 
 /**
  * A task that sends back the output to the driver application.
@@ -87,7 +88,15 @@ private[spark] class ResultTask[T, U](
       threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
     } else 0L
 
-    func(context, rdd.iterator(partition, context))
+    val iter = rdd.iterator(partition, context).asInstanceOf[InterruptibleIterator[T]]
+    val res = func(context, iter)
+    // SPARK-27568: operations such as take() may not consume all elements by the time func()
+    // returns, which would leak the read lock held on the block. So we manually call
+    // completion() for those operations here.
+    if (iter.hasNext && iter.delegate.isInstanceOf[CompletionIterator[T, Iterator[T]]]) {
+      iter.delegate.asInstanceOf[CompletionIterator[T, Iterator[T]]].completion()
+    }
+    res
   }
 
   // This is only callable on the driver side.
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 3490eaf550ce6..4573f82808ac8 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -731,6 +731,17 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
       resetSparkContext()
     }
   }
+
+  test("SPARK-27568: call take()/first() on a cached rdd should not leak readLock on the block") {
+    val conf = new SparkConf()
+      .setAppName("test")
+      .setMaster("local")
+      .set("spark.storage.exceptionOnPinLeak", "true")
+    sc = new SparkContext(conf)
+    // No exception, no pin leak
+    assert(sc.parallelize(Range(0, 10), 1).cache().take(1).head === 0)
+    assert(sc.parallelize(Range(0, 10), 1).cache().first() === 0)
+  }
 }
 
 object SparkContextSuite {
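
For context on why the ResultTask change above works, the sketch below is a minimal, standalone illustration in plain Scala. It does not use Spark; ReadLockLeakSketch and SimpleCompletionIterator are hypothetical names, and SimpleCompletionIterator is only a simplified stand-in for the idea behind org.apache.spark.util.CompletionIterator: a cleanup callback fires once the wrapped iterator is fully consumed. It shows how an early-terminating func(), as driven by take()/first(), leaves that callback unfired, and why calling completion() manually after func() returns releases the resource.

object ReadLockLeakSketch {
  // Hypothetical, simplified stand-in for Spark's CompletionIterator: runs a
  // cleanup callback exactly once, either when the underlying iterator is
  // exhausted or when completion() is called explicitly.
  class SimpleCompletionIterator[A](sub: Iterator[A], completionFunction: () => Unit)
      extends Iterator[A] {
    private var completed = false

    override def next(): A = sub.next()

    override def hasNext: Boolean = {
      val more = sub.hasNext
      if (!more) completion()   // normal path: cleanup fires on full consumption
      more
    }

    def completion(): Unit = {
      if (!completed) {
        completed = true
        completionFunction()    // e.g. releasing a read lock held on a cached block
      }
    }
  }

  def main(args: Array[String]): Unit = {
    var lockReleased = false
    val iter = new SimpleCompletionIterator(Iterator(1, 2, 3), () => { lockReleased = true })

    // Partial consumption, like take(1)/first(): the cleanup never runs on its own.
    println(iter.next())                       // 1
    println(s"lockReleased = $lockReleased")   // false -> the "pin leak"

    // The patched ResultTask does the equivalent of this after func() returns:
    if (iter.hasNext) iter.completion()
    println(s"lockReleased = $lockReleased")   // true
  }
}

In the patch itself the same check is guarded by an isInstanceOf test on iter.delegate, because only block-backed iterators wrap a CompletionIterator whose completion() releases the read lock tracked by spark.storage.exceptionOnPinLeak.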