11 changes: 10 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -25,6 +25,7 @@ import java.util.Properties
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.CompletionIterator

/**
* A task that sends back the output to the driver application.
@@ -87,7 +88,15 @@ private[spark] class ResultTask[T, U](
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L

-    func(context, rdd.iterator(partition, context))
+    val iter = rdd.iterator(partition, context).asInstanceOf[InterruptibleIterator[T]]
+    val res = func(context, iter)
+    // SPARK-27568: operations like take() can finish func() without consuming all elements
+    // of the iterator, which would leak the readLock taken on the cached block. So we
+    // manually call completion() on the underlying CompletionIterator in that case.
+    if (iter.hasNext && iter.delegate.isInstanceOf[CompletionIterator[T, Iterator[T]]]) {
+      iter.delegate.asInstanceOf[CompletionIterator[T, Iterator[T]]].completion()
+    }
+    res
}

// This is only callable on the driver side.
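For context: the read lock on a cached block is released by a completion callback that CompletionIterator fires only once the wrapped iterator is fully drained. The following standalone sketch (hypothetical names, with a simplified stand-in for org.apache.spark.util.CompletionIterator rather than the real class) illustrates why take()-style consumption leaves that callback unfired and how an explicit completion() call, as in the hunk above, compensates:

// Simplified stand-in for CompletionIterator (hypothetical sketch): runs
// completionFn exactly once, when the wrapped iterator is exhausted.
class SimpleCompletionIterator[A](sub: Iterator[A], completionFn: () => Unit) extends Iterator[A] {
  private var completed = false
  def completion(): Unit = if (!completed) { completed = true; completionFn() }
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more) completion()
    more
  }
  override def next(): A = sub.next()
}

object ReadLockLeakSketch {
  def main(args: Array[String]): Unit = {
    var readLockHeld = true // models the read lock pinned on a cached block
    val iter = new SimpleCompletionIterator[Int](Iterator(1, 2, 3), () => { readLockHeld = false })

    // take(1)-style consumption: func() returns before the iterator is drained,
    // so hasNext never observes exhaustion and the callback never fires.
    val first = iter.next()
    println(s"took $first, readLockHeld = $readLockHeld") // still true -> leaked

    // The fix: if elements remain after func() returns, call completion() manually.
    if (iter.hasNext) iter.completion()
    println(s"after manual completion, readLockHeld = $readLockHeld") // false
  }
}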
11 changes: 11 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -731,6 +731,17 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually
resetSparkContext()
}
}

test("SPARK-27568: call take()/first() on a cached rdd should not leak readLock on the block") {
val conf = new SparkConf()
.setAppName("test")
.setMaster("local")
.set("spark.storage.exceptionOnPinLeak", "true")
sc = new SparkContext(conf)
// No exception, no pin leak
assert(sc.parallelize(Range(0, 10), 1).cache().take(1).head === 0)
assert(sc.parallelize(Range(0, 10), 1).cache().first() === 0)
}
}

object SparkContextSuite {
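The test relies on spark.storage.exceptionOnPinLeak, which turns a leaked block pin (an unreleased read lock) at the end of a task into a task failure rather than a logged warning, so the assertions only pass when the lock is released. Roughly the same check can be run outside the suite; the snippet below is an illustrative sketch under those assumptions (names such as PinLeakCheck are hypothetical), not part of the PR:

import org.apache.spark.{SparkConf, SparkContext}

object PinLeakCheck {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("pin-leak-check")
      .setMaster("local")
      .set("spark.storage.exceptionOnPinLeak", "true") // fail the task on a leaked block pin
    val sc = new SparkContext(conf)
    try {
      // Before this fix, take(1)/first() on a cached RDD left the block's read lock
      // pinned, which would surface here as a task failure.
      println(sc.parallelize(0 until 10, 1).cache().take(1).head)
      println(sc.parallelize(0 until 10, 1).cache().first())
    } finally {
      sc.stop()
    }
  }
}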