diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b9952270bfc35..8a25311afdc4a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3628,6 +3628,26 @@ def test_udf_in_subquery(self):
         finally:
             self.spark.catalog.dropTempView("v")
 
+    # SPARK-33277
+    def test_udf_with_column_vector(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 100000, 1, 1).write.parquet(path)
+
+            def f(x):
+                return 0
+
+            fUdf = udf(f, LongType())
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEquals(
+                        self.spark.read.parquet(path).select(fUdf('id')).head(), Row(0))
+        finally:
+            shutil.rmtree(path)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
@@ -5575,6 +5595,28 @@ def test_datasource_with_udf(self):
         finally:
             shutil.rmtree(path)
 
+    # SPARK-33277
+    def test_pandas_udf_with_column_vector(self):
+        import pandas as pd
+        from pyspark.sql.functions import pandas_udf
+
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 200000, 1, 1).write.parquet(path)
+
+            @pandas_udf(LongType())
+            def udf(x):
+                return pd.Series([0] * len(x))
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEquals(
+                        self.spark.read.parquet(path).select(udf('id')).head(), Row(0))
+        finally:
+            shutil.rmtree(path)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index 942a6db57416e..293a7c0241e97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -88,6 +88,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
     inputRDD.mapPartitions { iter =>
       val context = TaskContext.get()
+      val contextAwareIterator = new ContextAwareIterator(iter, context)
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
@@ -119,7 +120,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
       })
 
       // Add rows to queue to join later with the result.
-      val projectedRowIter = iter.map { inputRow =>
+      val projectedRowIter = contextAwareIterator.map { inputRow =>
         queue.add(inputRow.asInstanceOf[UnsafeRow])
         projection(inputRow)
       }
@@ -136,3 +137,18 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
     }
   }
 }
+
+/**
+ * A TaskContext aware iterator.
+ *
+ * As the Python evaluation consumes the parent iterator in a separate thread,
+ * it could consume more data from the parent even after the task ends and the parent is closed.
+ * Thus, we should use ContextAwareIterator to stop consuming after the task ends.
+ */
+class ContextAwareIterator[IN](iter: Iterator[IN], context: TaskContext) extends Iterator[IN] {
+
+  override def hasNext: Boolean =
+    !context.isCompleted() && !context.isInterrupted() && iter.hasNext
+
+  override def next(): IN = iter.next()
+}
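
A minimal standalone sketch of the guard pattern that ContextAwareIterator implements above. Since a real TaskContext only exists inside a running Spark task, the AtomicBoolean `done` here is a hypothetical stand-in for `context.isCompleted() || context.isInterrupted()`; GuardedIterator and GuardedIteratorDemo are illustrative names, not part of the patch.

import java.util.concurrent.atomic.AtomicBoolean

// Wraps an iterator so that a consumer (e.g. a writer thread) stops pulling
// elements as soon as the flag is set, mirroring how ContextAwareIterator
// re-checks the TaskContext state before every read.
class GuardedIterator[A](underlying: Iterator[A], done: AtomicBoolean) extends Iterator[A] {
  override def hasNext: Boolean = !done.get() && underlying.hasNext
  override def next(): A = underlying.next()
}

object GuardedIteratorDemo extends App {
  val done = new AtomicBoolean(false)
  val it = new GuardedIterator(Iterator.from(0), done)

  println(it.next())  // 0 -- reads succeed while the "task" is still running
  done.set(true)      // simulate the task completing or being interrupted
  println(it.hasNext) // false -- the consumer stops, even though data remains
}

Placing the check in hasNext rather than next() means no further element is even requested from the parent once the flag flips, which is the point of the fix: the Python writer thread must not touch the parent iterator (backed by a closed column vector) after the task ends.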