From 4275f8284d13c49a46a15e48aed08a4114201e7e Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Wed, 31 Jul 2019 22:40:01 +0800 Subject: [PATCH] [SPARK-28153][PYTHON] Use AtomicReference at InputFileBlockHolder (to support input_file_name with Python UDF) This PR proposes to use `AtomicReference` so that parent and child threads can access to the same file block holder. Python UDF expressions are turned to a plan and then it launches a separate thread to consume the input iterator. In the separate child thread, the iterator sets `InputFileBlockHolder.set` before the parent does which the parent thread is unable to read later. 1. In this separate child thread, if it happens to call `InputFileBlockHolder.set` first without initialization of the parent's thread local (which is done when the `ThreadLocal.get()` is first called), the child thread seems calling its own `initialValue` to initialize. 2. After that, the parent calls its own `initialValue` to initializes at the first call of `ThreadLocal.get()`. 3. Both now have two different references. Updating at child isn't reflected to parent. This PR fixes it via initializing parent's thread local with `AtomicReference` for file status so that they can be used in each task, and children thread's update is reflected. I also tried to explain this a bit more at https://github.com/apache/spark/pull/24958#discussion_r297203041. Manually tested and unittest was added. Closes #24958 from HyukjinKwon/SPARK-28153. Authored-by: HyukjinKwon Signed-off-by: Wenchen Fan --- .../spark/rdd/InputFileBlockHolder.scala | 29 ++++++++++++++----- .../org/apache/spark/scheduler/Task.scala | 1 + python/pyspark/sql/tests.py | 8 +++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala index ff2f58d81142d..bfe8152d4dee2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import java.util.concurrent.atomic.AtomicReference + import org.apache.spark.unsafe.types.UTF8String /** @@ -40,26 +42,33 @@ private[spark] object InputFileBlockHolder { /** * The thread variable for the name of the current file being read. This is used by * the InputFileName function in Spark SQL. + * + * @note `inputBlock` works somewhat complicatedly. It guarantees that `initialValue` + * is called at the start of a task. Therefore, one atomic reference is created in the task + * thread. After that, read and write happen to the same atomic reference across the parent and + * children threads. This is in order to support a case where write happens in a child thread + * but read happens at its parent thread, for instance, Python UDF execution. See SPARK-28153. */ - private[this] val inputBlock: InheritableThreadLocal[FileBlock] = - new InheritableThreadLocal[FileBlock] { - override protected def initialValue(): FileBlock = new FileBlock + private[this] val inputBlock: InheritableThreadLocal[AtomicReference[FileBlock]] = + new InheritableThreadLocal[AtomicReference[FileBlock]] { + override protected def initialValue(): AtomicReference[FileBlock] = + new AtomicReference(new FileBlock) } /** * Returns the holding file name or empty string if it is unknown. */ - def getInputFilePath: UTF8String = inputBlock.get().filePath + def getInputFilePath: UTF8String = inputBlock.get().get().filePath /** * Returns the starting offset of the block currently being read, or -1 if it is unknown. */ - def getStartOffset: Long = inputBlock.get().startOffset + def getStartOffset: Long = inputBlock.get().get().startOffset /** * Returns the length of the block being read, or -1 if it is unknown. */ - def getLength: Long = inputBlock.get().length + def getLength: Long = inputBlock.get().get().length /** * Sets the thread-local input block. @@ -68,11 +77,17 @@ private[spark] object InputFileBlockHolder { require(filePath != null, "filePath cannot be null") require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative") require(length >= 0, s"length ($length) cannot be negative") - inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) + inputBlock.get().set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) } /** * Clears the input file block to default value. */ def unset(): Unit = inputBlock.remove() + + /** + * Initializes thread local by explicitly getting the value. It triggers ThreadLocal's + * initialValue in the parent thread. + */ + def initialize(): Unit = inputBlock.get() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 182f479fb0dde..daed55cc131c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -100,6 +100,7 @@ private[spark] abstract class Task[T]( taskContext } + InputFileBlockHolder.initialize() TaskContext.setTaskContext(context) taskThread = Thread.currentThread() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 38e7f39a2711d..f65fe885ef7d7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -847,6 +847,14 @@ def test_input_file_name_reset_for_rdd(self): for result in results: self.assertEqual(result[0], '') + def test_input_file_name_udf(self): + from pyspark.sql.functions import udf, input_file_name + + df = self.spark.read.text('python/test_support/hello/hello.txt') + df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file')) + file_name = df.collect()[0].file + self.assertTrue("python/test_support/hello/hello.txt" in file_name) + def test_udf_defers_judf_initialization(self): # This is separate of UDFInitializationTests # to avoid context initialization