diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py index 8f7229e1d74fd..b8dfe8314a727 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -401,6 +401,31 @@ def func(iterator): finally: shutil.rmtree(path) + def test_map_in_pandas_with_barrier_mode(self): + df = self.spark.range(10) + + def func1(iterator): + from pyspark import TaskContext, BarrierTaskContext + + tc = TaskContext.get() + assert tc is not None + assert not isinstance(tc, BarrierTaskContext) + for batch in iterator: + yield batch + + df.mapInPandas(func1, "id long", False).collect() + + def func2(iterator): + from pyspark import TaskContext, BarrierTaskContext + + tc = TaskContext.get() + assert tc is not None + assert isinstance(tc, BarrierTaskContext) + for batch in iterator: + yield batch + + df.mapInPandas(func2, "id long", True).collect() + class MapInPandasTests(ReusedSQLTestCase, MapInPandasTestsMixin): @classmethod diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py index 15367743585e3..0c2d5f36d0f27 100644 --- a/python/pyspark/sql/tests/test_arrow_map.py +++ b/python/pyspark/sql/tests/test_arrow_map.py @@ -146,6 +146,31 @@ def test_self_join(self): expected = df1.join(df1).collect() self.assertEqual(sorted(actual), sorted(expected)) + def test_map_in_arrow_with_barrier_mode(self): + df = self.spark.range(10) + + def func1(iterator): + from pyspark import TaskContext, BarrierTaskContext + + tc = TaskContext.get() + assert tc is not None + assert not isinstance(tc, BarrierTaskContext) + for batch in iterator: + yield batch + + df.mapInArrow(func1, "id long", False).collect() + + def func2(iterator): + from pyspark import TaskContext, BarrierTaskContext + + tc = TaskContext.get() + assert tc is not None + assert isinstance(tc, BarrierTaskContext) + for batch in iterator: + yield batch + + df.mapInArrow(func2, "id long", True).collect() + class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase): @classmethod