From 84855986942b410e65fe4e867626c538ff43a574 Mon Sep 17 00:00:00 2001 From: Vlad Rozov Date: Tue, 15 Apr 2025 09:56:34 -0700 Subject: [PATCH 1/4] [SPARK-51821][CORE] Call interrupt() without holding uninterruptibleLock to avoid possible deadlock --- .../spark/util/UninterruptibleThread.scala | 34 +++++++++++++--- .../util/UninterruptibleThreadSuite.scala | 40 +++++++++++++++++++ 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 24788d69121b2..99dcb2440d24c 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -51,6 +51,12 @@ private[spark] class UninterruptibleThread( @GuardedBy("uninterruptibleLock") private var shouldInterruptThread = false + /** + * Indicates that we should wait for interrupt() call before proceeding. + */ + @GuardedBy("uninterruptibleLock") + private var awaitInterruptThread = false + /** * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning * from `f`. @@ -69,10 +75,22 @@ private[spark] class UninterruptibleThread( } uninterruptibleLock.synchronized { + uninterruptible = true + } + + while (uninterruptibleLock.synchronized { // Clear the interrupted status if it's set. shouldInterruptThread = Thread.interrupted() || shouldInterruptThread - uninterruptible = true + // wait for super.interrupt() to be called + !shouldInterruptThread && awaitInterruptThread }) { + try { + Thread.sleep(100) + } catch { + case _: InterruptedException => + uninterruptibleLock.synchronized { shouldInterruptThread = true } + } } + try { f } finally { @@ -92,11 +110,17 @@ private[spark] class UninterruptibleThread( * interrupted until it enters into the interruptible status. 
*/ override def interrupt(): Unit = { - uninterruptibleLock.synchronized { - if (uninterruptible) { - shouldInterruptThread = true - } else { + if (uninterruptibleLock.synchronized { + shouldInterruptThread = uninterruptible + awaitInterruptThread = !shouldInterruptThread + awaitInterruptThread + }) { + try { super.interrupt() + } finally { + uninterruptibleLock.synchronized { + awaitInterruptThread = false + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala index 9c0ee1e1303ee..13ecf2b5dc961 100644 --- a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.util +import java.nio.channels.spi.AbstractInterruptibleChannel import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.util.Random @@ -115,6 +116,45 @@ class UninterruptibleThreadSuite extends SparkFunSuite { assert(interruptStatusBeforeExit) } + test("no runUninterruptibly") { + @volatile var hasInterruptedException = false + val t = new UninterruptibleThread("test") { + override def run(): Unit = { + if (sleep(0)) { + hasInterruptedException = true + } + } + } + t.interrupt() + t.start() + t.join() + assert(hasInterruptedException === true) + } + + test("SPARK-51821 uninterruptibleLock deadlock") { + val latch = new CountDownLatch(1) + val task = new UninterruptibleThread("task thread") { + override def run(): Unit = { + val channel = new AbstractInterruptibleChannel() { + override def implCloseChannel(): Unit = { + begin() + latch.countDown() + try { + Thread.sleep(Long.MaxValue) + } catch { + case _: InterruptedException => Thread.currentThread().interrupt() + } + } + } + channel.close() + } + } + task.start() + assert(latch.await(100, TimeUnit.SECONDS), "await timeout") + task.interrupt() + task.join() + } + test("stress test") { 
@volatile var hasInterruptedException = false val t = new UninterruptibleThread("test") { From bebf0cb05ef05d39511e6872bc2b745c86c2a8b0 Mon Sep 17 00:00:00 2001 From: Vlad Rozov Date: Wed, 23 Apr 2025 18:43:00 -0700 Subject: [PATCH 2/4] Fixed typo, added comment to clarify how awaitInterruptThread flag is used and handle case where interrupt is called from more than one thread concurrently --- .../spark/util/UninterruptibleThread.scala | 14 ++++++++++++-- .../util/UninterruptibleThreadSuite.scala | 19 +++++++++++++++---- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 99dcb2440d24c..2ed62a232bd83 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -112,8 +112,18 @@ private[spark] class UninterruptibleThread( override def interrupt(): Unit = { if (uninterruptibleLock.synchronized { shouldInterruptThread = uninterruptible - awaitInterruptThread = !shouldInterruptThread - awaitInterruptThread + // as we are releasing uninterruptibleLock before calling super.interrupt() there is a + // possibility that runUninterruptibly() would be called after lock is released but before + // super.interrupt() is called. In this case to prevent runUninterruptibly() from being + // interrupted, we use awaitInterruptThread flag. 
We need to set it only if + // runUninterruptibly() is not yet set uninterruptible to true (!shouldInterruptThread) and + // there is no other threads that called interrupt (awaitInterruptThread is already true) + if (!shouldInterruptThread && !awaitInterruptThread) { + awaitInterruptThread = true + true + } else { + false + } }) { try { super.interrupt() diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala index 13ecf2b5dc961..fbc954d05af82 100644 --- a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala @@ -150,7 +150,7 @@ class UninterruptibleThreadSuite extends SparkFunSuite { } } task.start() - assert(latch.await(100, TimeUnit.SECONDS), "await timeout") + assert(latch.await(10, TimeUnit.SECONDS), "await timeout") task.interrupt() task.join() } @@ -188,9 +188,20 @@ class UninterruptibleThreadSuite extends SparkFunSuite { } } t.start() - for (i <- 0 until 400) { - Thread.sleep(Random.nextInt(10)) - t.interrupt() + val threads = new Array[Thread](10) + for (j <- 0 until 10) { + threads(j) = new Thread() { + override def run(): Unit = { + for (i <- 0 until 400) { + Thread.sleep(Random.nextInt(10)) + t.interrupt() + } + } + } + threads(j).start() + } + for (j <- 0 until 10) { + threads(j).join() } t.join() assert(hasInterruptedException === false) From d0be1a3c82176a5c4e86c922b213dc4fcfa08ddc Mon Sep 17 00:00:00 2001 From: Vlad Rozov Date: Mon, 5 May 2025 09:06:20 -0700 Subject: [PATCH 3/4] introduce UninterruptibleLock --- .../spark/util/UninterruptibleThread.scala | 54 +++++++++++-------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 2ed62a232bd83..50d89cd8346ac 100644 --- 
a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -35,8 +35,38 @@ private[spark] class UninterruptibleThread( this(null, name) } + class UninterruptibleLock { + def awaitInterrupt(): Boolean = synchronized { + // Clear the interrupted status if it's set. + shouldInterruptThread = Thread.interrupted() || shouldInterruptThread + // wait for super.interrupt() to be called + !shouldInterruptThread && awaitInterruptThread + } + + /** + * Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread + * @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]] + * is false) and no concurrent [[interrupt()]] call, otherwise false + */ + def isInterruptible: Boolean = synchronized { + shouldInterruptThread = uninterruptible + // as we are releasing uninterruptibleLock before calling super.interrupt() there is a + // possibility that runUninterruptibly() would be called after lock is released but before + // super.interrupt() is called. In this case to prevent runUninterruptibly() from being + // interrupted, we use awaitInterruptThread flag. We need to set it only if + // runUninterruptibly() has not yet set uninterruptible to true (!shouldInterruptThread) and + // no other thread is already about to call super.interrupt() (!awaitInterruptThread) + if (!shouldInterruptThread && !awaitInterruptThread) { + awaitInterruptThread = true + true + } else { + false + } + } + } + /** A monitor to protect "uninterruptible" and "interrupted" */ - private val uninterruptibleLock = new Object + private val uninterruptibleLock = new UninterruptibleLock /** * Indicates if `this` thread are in the uninterruptible status. If so, interrupting @@ -78,11 +108,7 @@ private[spark] class UninterruptibleThread( uninterruptible = true } - while (uninterruptibleLock.synchronized { - // Clear the interrupted status if it's set. 
- shouldInterruptThread = Thread.interrupted() || shouldInterruptThread - // wait for super.interrupt() to be called - !shouldInterruptThread && awaitInterruptThread }) { + while (uninterruptibleLock.awaitInterrupt()) { try { Thread.sleep(100) } catch { @@ -110,21 +136,7 @@ private[spark] class UninterruptibleThread( * interrupted until it enters into the interruptible status. */ override def interrupt(): Unit = { - if (uninterruptibleLock.synchronized { - shouldInterruptThread = uninterruptible - // as we are releasing uninterruptibleLock before calling super.interrupt() there is a - // possibility that runUninterruptibly() would be called after lock is released but before - // super.interrupt() is called. In this case to prevent runUninterruptibly() from being - // interrupted, we use awaitInterruptThread flag. We need to set it only if - // runUninterruptibly() is not yet set uninterruptible to true (!shouldInterruptThread) and - // there is no other threads that called interrupt (awaitInterruptThread is already true) - if (!shouldInterruptThread && !awaitInterruptThread) { - awaitInterruptThread = true - true - } else { - false - } - }) { + if (uninterruptibleLock.isInterruptible) { try { super.interrupt() } finally { From b7e64931e7ff02e3e6b6e09a566d4980e7940db3 Mon Sep 17 00:00:00 2001 From: Vlad Rozov Date: Tue, 6 May 2025 09:18:36 -0700 Subject: [PATCH 4/4] refactored methods to UninterruptibleLock --- .../spark/util/UninterruptibleThread.scala | 99 +++++++++++-------- 1 file changed, 59 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 50d89cd8346ac..8fba5ed944c67 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -35,14 +35,66 @@ private[spark] class UninterruptibleThread( this(null, name) } - class 
UninterruptibleLock { - def awaitInterrupt(): Boolean = synchronized { + private class UninterruptibleLock { + /** + * Indicates if `this` thread is in the uninterruptible status. If so, interrupting + * `this` will be deferred until `this` enters into the interruptible status. + */ + @GuardedBy("uninterruptibleLock") + private var uninterruptible = false + + /** + * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone. + */ + @GuardedBy("uninterruptibleLock") + private var shouldInterruptThread = false + + /** + * Indicates that we should wait for interrupt() call before proceeding. + */ + @GuardedBy("uninterruptibleLock") + private var awaitInterruptThread = false + + /** + * Sets [[uninterruptible]] to the given value and returns the previous value. + */ + def getAndSetUninterruptible(value: Boolean): Boolean = synchronized { + val uninterruptible = this.uninterruptible + this.uninterruptible = value + uninterruptible + } + + def setShouldInterruptThread(value: Boolean): Unit = synchronized { + shouldInterruptThread = value + } + + def setAwaitInterruptThread(value: Boolean): Unit = synchronized { + awaitInterruptThread = value + } + + /** + * Whether a call to [[java.lang.Thread.interrupt()]] is still pending. + */ + def isInterruptPending: Boolean = synchronized { // Clear the interrupted status if it's set. 
shouldInterruptThread = Thread.interrupted() || shouldInterruptThread // wait for super.interrupt() to be called !shouldInterruptThread && awaitInterruptThread } + /** + * Set [[uninterruptible]] back to false and call [[java.lang.Thread.interrupt()]] to + * recover interrupt state if necessary + */ + def recoverInterrupt(): Unit = synchronized { + uninterruptible = false + if (shouldInterruptThread) { + shouldInterruptThread = false + // Recover the interrupted status + UninterruptibleThread.super.interrupt() + } + } + /** * Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread * @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]] @@ -68,25 +120,6 @@ private[spark] class UninterruptibleThread( /** A monitor to protect "uninterruptible" and "interrupted" */ private val uninterruptibleLock = new UninterruptibleLock - /** - * Indicates if `this` thread are in the uninterruptible status. If so, interrupting - * "this" will be deferred until `this` enters into the interruptible status. - */ - @GuardedBy("uninterruptibleLock") - private var uninterruptible = false - - /** - * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone. - */ - @GuardedBy("uninterruptibleLock") - private var shouldInterruptThread = false - - /** - * Indicates that we should wait for interrupt() call before proceeding. - */ - @GuardedBy("uninterruptibleLock") - private var awaitInterruptThread = false - /** * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning * from `f`. @@ -99,35 +132,23 @@ private[spark] class UninterruptibleThread( s"Expected: $this but was ${Thread.currentThread()}") } - if (uninterruptibleLock.synchronized { uninterruptible }) { + if (uninterruptibleLock.getAndSetUninterruptible(true)) { // We are already in the uninterruptible status. 
So just run "f" and return return f } - uninterruptibleLock.synchronized { - uninterruptible = true - } - - while (uninterruptibleLock.awaitInterrupt()) { + while (uninterruptibleLock.isInterruptPending) { try { Thread.sleep(100) } catch { - case _: InterruptedException => - uninterruptibleLock.synchronized { shouldInterruptThread = true } + case _: InterruptedException => uninterruptibleLock.setShouldInterruptThread(true) } } try { f } finally { - uninterruptibleLock.synchronized { - uninterruptible = false - if (shouldInterruptThread) { - // Recover the interrupted status - super.interrupt() - shouldInterruptThread = false - } - } + uninterruptibleLock.recoverInterrupt() } } @@ -140,9 +161,7 @@ private[spark] class UninterruptibleThread( try { super.interrupt() } finally { - uninterruptibleLock.synchronized { - awaitInterruptThread = false - } + uninterruptibleLock.setAwaitInterruptThread(false) } } }