From 39f1a574bd1d0e3eb5c470ef8b0bc58dd204727a Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Tue, 17 Mar 2026 17:34:52 +0200 Subject: [PATCH 01/11] test pyspark from scala --- .../DataFlintPythonIntegrationSpec.scala | 26 +++++++++++++++++-- .../dataflint_python_exec_integration_test.py | 6 ++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala index fcc84fc..e038853 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala @@ -14,15 +14,36 @@ import java.nio.file.Paths * PythonRunner.main() creates a Py4JServer internally and sets PYSPARK_GATEWAY_PORT * for the subprocess — no manual gateway setup needed. The Python script then connects * to this JVM via launch_gateway() and accesses the session through DataFlintStaticSession. + * + * Requires: .venv with pyspark, pandas, pyarrow installed. + * python3 -m venv .venv && .venv/bin/pip install pyspark pandas pyarrow */ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll { + // pluginspark3 tests run with CWD = spark-plugin/pluginspark3/, so go up one level + // to reach the project root where .venv and pyspark-testing live. + private val projectRoot = Paths.get("").toAbsolutePath.getParent + + private val venvPython: String = { + val p = projectRoot.resolve(Paths.get(".venv", "bin", "python3")) + require(p.toFile.exists(), + s"Python venv not found at $p\n" + + "Run: python3 -m venv .venv && .venv/bin/pip install pyspark pandas pyarrow") + p.toString + } + private var spark: SparkSession = _ override def beforeAll(): Unit = { + // Set before SparkSession creation so the conf is also applied to UDF workers. + // Also sets it as a system property so PythonRunner's internal `new SparkConf()` + // picks it up when resolving the Python executable for the subprocess. + System.setProperty("spark.pyspark.python", venvPython) + spark = SparkSession.builder() .master("local[2]") .appName("DataFlintPythonIntegrationSpec") + .config("spark.pyspark.python", venvPython) .config(DataflintSparkUICommonLoader.INSTRUMENT_SPARK_ENABLED, "true") .config(DataflintSparkUICommonLoader.INSTRUMENT_ARROW_EVAL_PYTHON_ENABLED, "true") .config(DataflintSparkUICommonLoader.INSTRUMENT_BATCH_EVAL_PYTHON_ENABLED, "true") @@ -37,12 +58,13 @@ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with Befo override def afterAll(): Unit = { DataFlintStaticSession.clear() + System.clearProperty("spark.pyspark.python") if (spark != null) spark.stop() } test("all 4 DataFlint Python exec nodes are instrumented and visible in the plan") { - val scriptPath = Paths.get("pyspark-testing", "dataflint_python_exec_integration_test.py") - .toAbsolutePath.toString + val scriptPath = projectRoot.resolve( + Paths.get("pyspark-testing", "dataflint_python_exec_integration_test.py")).toString // PythonRunner sets PYSPARK_GATEWAY_PORT + PYSPARK_GATEWAY_SECRET for the subprocess, // wires up PYTHONPATH (pyspark + py4j), and throws SparkException on non-zero exit. PythonRunner.main(Array(scriptPath, "")) diff --git a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py index 2fe81ba..8004c8c 100644 --- a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py +++ b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py @@ -28,7 +28,11 @@ jsc = static.javaSparkContext() spark_jvm = static.session() -sc = pyspark.SparkContext(gateway=gateway, jsc=jsc) +# Build SparkConf from the existing JavaSparkContext so spark.master (and all other +# existing configs) are present — without this PySpark raises MASTER_URL_NOT_SET. +conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf()) + +sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf) spark = SparkSession(sc, jsparkSession=spark_jvm) print(f"Connected to Spark {spark.version}") From ecb893c853a26d8c18ca7e0f65f4bb8b6cf8faf1 Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Tue, 17 Mar 2026 20:06:12 +0200 Subject: [PATCH 02/11] working example of py test from scala --- .../DataFlintPythonIntegrationSpec.scala | 18 ++- .../dataflint_python_exec_integration_test.py | 111 ++++-------------- 2 files changed, 38 insertions(+), 91 deletions(-) diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala index e038853..d0268ea 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala @@ -22,7 +22,7 @@ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with Befo // pluginspark3 tests run with CWD = spark-plugin/pluginspark3/, so go up one level // to reach the project root where .venv and pyspark-testing live. - private val projectRoot = Paths.get("").toAbsolutePath.getParent + private val projectRoot = Paths.get("").toAbsolutePath//.getParent private val venvPython: String = { val p = projectRoot.resolve(Paths.get(".venv", "bin", "python3")) @@ -49,6 +49,7 @@ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with Befo .config(DataflintSparkUICommonLoader.INSTRUMENT_BATCH_EVAL_PYTHON_ENABLED, "true") .config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_GROUPS_PANDAS_ENABLED, "true") .config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_COGROUPS_PANDAS_ENABLED, "true") +// .config("spark.sql.adaptive.enabled", "false") .config("spark.ui.enabled", "false") .withExtensions(new DataFlintInstrumentationExtension) .getOrCreate() @@ -67,6 +68,21 @@ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with Befo Paths.get("pyspark-testing", "dataflint_python_exec_integration_test.py")).toString // PythonRunner sets PYSPARK_GATEWAY_PORT + PYSPARK_GATEWAY_SECRET for the subprocess, // wires up PYTHONPATH (pyspark + py4j), and throws SparkException on non-zero exit. + // The script registers 4 temp views; we check their physical plans here on the Scala side. PythonRunner.main(Array(scriptPath, "")) + + val checks = Seq( + "batch_eval_view" -> "DataFlintBatchEvalPython", + "arrow_eval_view" -> "DataFlintArrowEvalPython", + "flat_map_groups_view" -> "DataFlintFlatMapGroupsInPandas", + "flat_map_cogroups_view" -> "DataFlintFlatMapCoGroupsInPandas", + ) + + checks.foreach { case (view, expectedNode) => + val df = spark.table(view) + df.collect() + val plan = df.queryExecution.executedPlan.toString + plan should include(expectedNode) + } } } \ No newline at end of file diff --git a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py index 8004c8c..f3da468 100644 --- a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py +++ b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py @@ -1,16 +1,15 @@ """ -Integration test: verifies all 4 DataFlint Python exec nodes are instrumented. +Integration test helper: registers 4 DataFlint Python exec node scenarios as temp views. -Connects to the Scala test's SparkSession via the Py4J gateway that PythonRunner -sets up automatically (PYSPARK_GATEWAY_PORT / PYSPARK_GATEWAY_SECRET env vars). +The Scala test (DataFlintPythonIntegrationSpec) connects via PythonRunner / Py4J, +calls this script, then checks each view's executedPlan for the DataFlint node. -Tested nodes: - BatchEvalPython -> DataFlintBatchEvalPython (@udf) - ArrowEvalPython -> DataFlintArrowEvalPython (@pandas_udf scalar) - FlatMapGroupsInPandas -> DataFlintFlatMapGroupsInPandas (applyInPandas) - FlatMapCoGroupsInPandas-> DataFlintFlatMapCoGroupsInPandas (cogroup applyInPandas) +Registered views: + batch_eval_view — @udf → DataFlintBatchEvalPython + arrow_eval_view — @pandas_udf scalar → DataFlintArrowEvalPython + flat_map_groups_view — applyInPandas → DataFlintFlatMapGroupsInPandas + flat_map_cogroups_view — cogroup.applyIn… → DataFlintFlatMapCoGroupsInPandas """ -import sys import pyspark import pyspark.java_gateway from pyspark.sql import SparkSession @@ -30,89 +29,38 @@ # Build SparkConf from the existing JavaSparkContext so spark.master (and all other # existing configs) are present — without this PySpark raises MASTER_URL_NOT_SET. -conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf()) - +conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf()) sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf) spark = SparkSession(sc, jsparkSession=spark_jvm) -print(f"Connected to Spark {spark.version}") - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- -failures = [] - -def executed_plan_str(df): - """Return the full executedPlan string after triggering execution.""" - df.collect() - return df._jdf.queryExecution().executedPlan().toString() - -def assert_dataflint_node(df, expected_node, test_name): - plan = executed_plan_str(df) - if expected_node in plan: - print(f" PASS [{test_name}] — {expected_node} found in plan") - else: - msg = f"[{test_name}] '{expected_node}' not found in plan:\n{plan}" - failures.append(msg) - print(f" FAIL {msg}") - -# --------------------------------------------------------------------------- -# Sample data -# --------------------------------------------------------------------------- df = spark.createDataFrame( [(1, "a"), (2, "b"), (3, "a"), (4, "b")], ["id", "cat"] ) -# --------------------------------------------------------------------------- -# Test 1 — @udf → BatchEvalPython → DataFlintBatchEvalPython -# --------------------------------------------------------------------------- -print("\n[1] BatchEvalPython (@udf)") - +# 1 — BatchEvalPython @udf(returnType=LongType()) def double_udf(x): return x * 2 -assert_dataflint_node( - df.select(double_udf("id")), - "DataFlintBatchEvalPython", - "BatchEvalPython", -) - -# --------------------------------------------------------------------------- -# Test 2 — @pandas_udf scalar → ArrowEvalPython → DataFlintArrowEvalPython -# --------------------------------------------------------------------------- -print("\n[2] ArrowEvalPython (@pandas_udf scalar)") +df.select(double_udf("id")).createOrReplaceTempView("batch_eval_view") +# 2 — ArrowEvalPython @pandas_udf(LongType()) def double_pandas_udf(s: pd.Series) -> pd.Series: return s * 2 -assert_dataflint_node( - df.select(double_pandas_udf("id")), - "DataFlintArrowEvalPython", - "ArrowEvalPython", -) - -# --------------------------------------------------------------------------- -# Test 3 — applyInPandas → FlatMapGroupsInPandas → DataFlintFlatMapGroupsInPandas -# --------------------------------------------------------------------------- -print("\n[3] FlatMapGroupsInPandas (applyInPandas)") +df.select(double_pandas_udf("id")).createOrReplaceTempView("arrow_eval_view") +# 3 — FlatMapGroupsInPandas def identity_group(key, pdf): return pdf -assert_dataflint_node( - df.groupby("cat").applyInPandas(identity_group, schema="id long, cat string"), - "DataFlintFlatMapGroupsInPandas", - "FlatMapGroupsInPandas", -) - -# --------------------------------------------------------------------------- -# Test 4 — cogroup applyInPandas → FlatMapCoGroupsInPandas → DataFlintFlatMapCoGroupsInPandas -# --------------------------------------------------------------------------- -print("\n[4] FlatMapCoGroupsInPandas (cogroup applyInPandas)") +df.groupby("cat").applyInPandas( + identity_group, schema="id long, cat string" +).createOrReplaceTempView("flat_map_groups_view") +# 4 — FlatMapCoGroupsInPandas df2 = spark.createDataFrame( [(1, "x"), (2, "y"), (3, "z"), (4, "w")], ["id", "label"] @@ -123,23 +71,6 @@ def cogroup_fn(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: left["label"] = right["label"].values[0] if len(right) > 0 else None return left[["id", "cat", "label"]] -assert_dataflint_node( - df.groupby("id").cogroup(df2.groupby("id")).applyInPandas( - cogroup_fn, schema="id long, cat string, label string" - ), - "DataFlintFlatMapCoGroupsInPandas", - "FlatMapCoGroupsInPandas", -) - -# --------------------------------------------------------------------------- -# Summary -# --------------------------------------------------------------------------- -print("\n" + "=" * 60) -if failures: - print(f"FAILED — {len(failures)} test(s) did not find the expected DataFlint node:") - for f in failures: - print(f" • {f}") - sys.exit(1) -else: - print("ALL PASSED — all 4 DataFlint Python exec nodes are instrumented") - sys.exit(0) \ No newline at end of file +df.groupby("id").cogroup(df2.groupby("id")).applyInPandas( + cogroup_fn, schema="id long, cat string, label string" +).createOrReplaceTempView("flat_map_cogroups_view") \ No newline at end of file From c0f38fab4c60c51dc46f897a815c8ef7d5636e1b Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Wed, 18 Mar 2026 14:15:42 +0200 Subject: [PATCH 03/11] clean up for test code --- .../dataflint/DataFlintPythonExecSpec.scala | 2 +- .../DataFlintPythonIntegrationSpec.scala | 64 +++++++------ .../dataflint/DataFlintWindowExecSpec.scala | 10 +- .../spark/dataflint/SqlMetricTestHelper.scala | 24 ----- .../dataflint_python_exec_integration_test.py | 95 ++++++++++--------- 5 files changed, 91 insertions(+), 104 deletions(-) delete mode 100644 spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/SqlMetricTestHelper.scala diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala index 7f3e5f9..f3e3fba 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala @@ -10,7 +10,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers -class DataFlintPythonExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll { +class DataFlintPythonExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper { private var spark: SparkSession = _ diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala index d0268ea..49af685 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala @@ -2,6 +2,7 @@ package org.apache.spark.dataflint import org.apache.spark.deploy.PythonRunner import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.SparkPlan import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers @@ -18,7 +19,7 @@ import java.nio.file.Paths * Requires: .venv with pyspark, pandas, pyarrow installed. * python3 -m venv .venv && .venv/bin/pip install pyspark pandas pyarrow */ -class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll { +class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper { // pluginspark3 tests run with CWD = spark-plugin/pluginspark3/, so go up one level // to reach the project root where .venv and pyspark-testing live. @@ -32,6 +33,10 @@ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with Befo p.toString } + private val scriptPath: String = + projectRoot.resolve( + Paths.get("pyspark-testing", "dataflint_python_exec_integration_test.py")).toString + private var spark: SparkSession = _ override def beforeAll(): Unit = { @@ -43,13 +48,12 @@ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with Befo spark = SparkSession.builder() .master("local[2]") .appName("DataFlintPythonIntegrationSpec") - .config("spark.pyspark.python", venvPython) - .config(DataflintSparkUICommonLoader.INSTRUMENT_SPARK_ENABLED, "true") - .config(DataflintSparkUICommonLoader.INSTRUMENT_ARROW_EVAL_PYTHON_ENABLED, "true") - .config(DataflintSparkUICommonLoader.INSTRUMENT_BATCH_EVAL_PYTHON_ENABLED, "true") + .config("spark.pyspark.python", venvPython) + .config(DataflintSparkUICommonLoader.INSTRUMENT_SPARK_ENABLED, "true") + .config(DataflintSparkUICommonLoader.INSTRUMENT_ARROW_EVAL_PYTHON_ENABLED, "true") + .config(DataflintSparkUICommonLoader.INSTRUMENT_BATCH_EVAL_PYTHON_ENABLED, "true") .config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_GROUPS_PANDAS_ENABLED, "true") .config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_COGROUPS_PANDAS_ENABLED, "true") -// .config("spark.sql.adaptive.enabled", "false") .config("spark.ui.enabled", "false") .withExtensions(new DataFlintInstrumentationExtension) .getOrCreate() @@ -63,26 +67,32 @@ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with Befo if (spark != null) spark.stop() } - test("all 4 DataFlint Python exec nodes are instrumented and visible in the plan") { - val scriptPath = projectRoot.resolve( - Paths.get("pyspark-testing", "dataflint_python_exec_integration_test.py")).toString - // PythonRunner sets PYSPARK_GATEWAY_PORT + PYSPARK_GATEWAY_SECRET for the subprocess, - // wires up PYTHONPATH (pyspark + py4j), and throws SparkException on non-zero exit. - // The script registers 4 temp views; we check their physical plans here on the Scala side. - PythonRunner.main(Array(scriptPath, "")) - - val checks = Seq( - "batch_eval_view" -> "DataFlintBatchEvalPython", - "arrow_eval_view" -> "DataFlintArrowEvalPython", - "flat_map_groups_view" -> "DataFlintFlatMapGroupsInPandas", - "flat_map_cogroups_view" -> "DataFlintFlatMapCoGroupsInPandas", - ) - - checks.foreach { case (view, expectedNode) => - val df = spark.table(view) - df.collect() - val plan = df.queryExecution.executedPlan.toString - plan should include(expectedNode) - } + /** Run the Python script for the given test name, then check the registered view. */ + private def assertPythonNode(testName: String, view: String, expectedNode: String): Unit = { + PythonRunner.main(Array(scriptPath, "", testName)) + val df = spark.table(view) + df.collect() + val node = finalPlan(df).collect { + case n: SparkPlan if n.getClass.getSimpleName.contains(expectedNode) => n + }.headOption + node shouldBe defined + node.get.metrics should contain key "duration" + node.get.metrics("duration").value should be > 0L + } + + test("BatchEvalPython (@udf) is instrumented") { + assertPythonNode("batch_eval", "batch_eval_view", "DataFlintBatchEvalPython") + } + + test("ArrowEvalPython (@pandas_udf scalar) is instrumented") { + assertPythonNode("arrow_eval", "arrow_eval_view", "DataFlintArrowEvalPython") + } + + test("FlatMapGroupsInPandas (applyInPandas) is instrumented") { + assertPythonNode("flat_map_groups", "flat_map_groups_view", "DataFlintFlatMapGroupsInPandas") + } + + test("FlatMapCoGroupsInPandas (cogroup applyInPandas) is instrumented") { + assertPythonNode("flat_map_cogroups", "flat_map_cogroups_view", "DataFlintFlatMapCoGroupsInPandas") } } \ No newline at end of file diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala index ded8182..41fb52d 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala @@ -2,7 +2,6 @@ package org.apache.spark.dataflint import org.apache.spark.sql.execution.{ExplicitRepartitionExtension, ExplicitRepartitionOps} import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.window.DataFlintWindowExec import org.apache.spark.sql.expressions.{Aggregator, Window} import org.apache.spark.sql.functions.{col, rank, udaf} @@ -28,7 +27,7 @@ private class SlowSumAggregator(fromSleep: Long, toSleep: Long) extends Aggregat def outputEncoder: Encoder[Long] = Encoders.scalaLong } -class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with SqlMetricTestHelper { +class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper { private var spark: SparkSession = _ @@ -52,13 +51,6 @@ class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAf if (spark != null) spark.stop() } - // With AQE, executedPlan is AdaptiveSparkPlanExec. After collect(), finalPhysicalPlan holds - // the fully optimised plan. For plan-structure tests that don't execute the query, use - // queryExecution.sparkPlan instead (our strategy runs before AQE wraps the plan). - private def finalPlan(df: DataFrame) = df.queryExecution.executedPlan match { - case aqe: AdaptiveSparkPlanExec => aqe.finalPhysicalPlan - case p => p - } test("DataFlintWindowPlannerStrategy replaces WindowExec with DataFlintWindowExec for SQL window") { val session = spark diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/SqlMetricTestHelper.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/SqlMetricTestHelper.scala deleted file mode 100644 index 88e6578..0000000 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/SqlMetricTestHelper.scala +++ /dev/null @@ -1,24 +0,0 @@ -package org.apache.spark.dataflint - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.metric.SQLMetric -import org.scalatest.Assertions - -case class MetricStats(total: Long, min: Long, med: Long, max: Long) - -trait SqlMetricTestHelper extends Assertions { - - // Reads per-task total/min/med/max from the SQL status store for the given SQLMetric. - // createTimingMetric records each task's elapsed time individually; the store formats them as: - // "total (min, med, max (stageId: taskId))\nX ms (Y ms, Z ms, W ms (stage A.B: task C))" - def metricMinMax(metric: SQLMetric)(implicit spark: SparkSession): MetricStats = { - val sqlStore = spark.sharedState.statusStore - val execData = sqlStore.executionsList().maxBy(_.executionId) - val metricStr = sqlStore.executionMetrics(execData.executionId).getOrElse(metric.id, "") - val pattern = """(\d+) ms \((\d+) ms, (\d+) ms, (\d+) ms""".r - pattern.findFirstMatchIn(metricStr) match { - case Some(m) => MetricStats(m.group(1).toLong, m.group(2).toLong, m.group(3).toLong, m.group(4).toLong) - case None => fail(s"Expected per-task timing breakdown but got: '$metricStr'") - } - } -} \ No newline at end of file diff --git a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py index f3da468..468e8b7 100644 --- a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py +++ b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py @@ -1,15 +1,16 @@ """ -Integration test helper: registers 4 DataFlint Python exec node scenarios as temp views. +Integration test helper: registers one DataFlint Python exec node scenario as a temp view. -The Scala test (DataFlintPythonIntegrationSpec) connects via PythonRunner / Py4J, -calls this script, then checks each view's executedPlan for the DataFlint node. +Usage: pass the test name as the first argument (sys.argv[1]): + batch_eval — @udf → registers batch_eval_view + arrow_eval — @pandas_udf scalar → registers arrow_eval_view + flat_map_groups — applyInPandas → registers flat_map_groups_view + flat_map_cogroups— cogroup.applyIn… → registers flat_map_cogroups_view -Registered views: - batch_eval_view — @udf → DataFlintBatchEvalPython - arrow_eval_view — @pandas_udf scalar → DataFlintArrowEvalPython - flat_map_groups_view — applyInPandas → DataFlintFlatMapGroupsInPandas - flat_map_cogroups_view — cogroup.applyIn… → DataFlintFlatMapCoGroupsInPandas +The Scala test (DataFlintPythonIntegrationSpec) calls PythonRunner.main with one of +the above names, then checks the corresponding view's executedPlan. """ +import sys import pyspark import pyspark.java_gateway from pyspark.sql import SparkSession @@ -17,18 +18,13 @@ from pyspark.sql.types import LongType import pandas as pd -# --------------------------------------------------------------------------- -# Connect to the Scala test's SparkSession via the existing Py4J gateway. -# PythonRunner has already set PYSPARK_GATEWAY_PORT in this process's env, -# so launch_gateway() connects to the running JVM instead of launching a new one. -# --------------------------------------------------------------------------- +test_name = sys.argv[1] if len(sys.argv) > 1 else "" + gateway = pyspark.java_gateway.launch_gateway() static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession jsc = static.javaSparkContext() spark_jvm = static.session() -# Build SparkConf from the existing JavaSparkContext so spark.master (and all other -# existing configs) are present — without this PySpark raises MASTER_URL_NOT_SET. conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf()) sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf) spark = SparkSession(sc, jsparkSession=spark_jvm) @@ -38,39 +34,52 @@ ["id", "cat"] ) -# 1 — BatchEvalPython -@udf(returnType=LongType()) -def double_udf(x): - return x * 2 -df.select(double_udf("id")).createOrReplaceTempView("batch_eval_view") +def register_batch_eval(): + @udf(returnType=LongType()) + def double_udf(x): + return x * 2 + df.select(double_udf("id")).createOrReplaceTempView("batch_eval_view") -# 2 — ArrowEvalPython -@pandas_udf(LongType()) -def double_pandas_udf(s: pd.Series) -> pd.Series: - return s * 2 -df.select(double_pandas_udf("id")).createOrReplaceTempView("arrow_eval_view") +def register_arrow_eval(): + @pandas_udf(LongType()) + def double_pandas_udf(s: pd.Series) -> pd.Series: + return s * 2 + df.select(double_pandas_udf("id")).createOrReplaceTempView("arrow_eval_view") -# 3 — FlatMapGroupsInPandas -def identity_group(key, pdf): - return pdf -df.groupby("cat").applyInPandas( - identity_group, schema="id long, cat string" -).createOrReplaceTempView("flat_map_groups_view") +def register_flat_map_groups(): + def identity_group(key, pdf): + return pdf + df.groupby("cat").applyInPandas( + identity_group, schema="id long, cat string" + ).createOrReplaceTempView("flat_map_groups_view") -# 4 — FlatMapCoGroupsInPandas -df2 = spark.createDataFrame( - [(1, "x"), (2, "y"), (3, "z"), (4, "w")], - ["id", "label"] -) -def cogroup_fn(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: - left = left.copy() - left["label"] = right["label"].values[0] if len(right) > 0 else None - return left[["id", "cat", "label"]] +def register_flat_map_cogroups(): + df2 = spark.createDataFrame( + [(1, "x"), (2, "y"), (3, "z"), (4, "w")], + ["id", "label"] + ) + def cogroup_fn(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + left = left.copy() + left["label"] = right["label"].values[0] if len(right) > 0 else None + return left[["id", "cat", "label"]] + df.groupby("id").cogroup(df2.groupby("id")).applyInPandas( + cogroup_fn, schema="id long, cat string, label string" + ).createOrReplaceTempView("flat_map_cogroups_view") + + +_tests = { + "batch_eval": register_batch_eval, + "arrow_eval": register_arrow_eval, + "flat_map_groups": register_flat_map_groups, + "flat_map_cogroups": register_flat_map_cogroups, +} + +if test_name not in _tests: + print(f"Unknown test '{test_name}'. Valid options: {list(_tests)}", file=sys.stderr) + sys.exit(1) -df.groupby("id").cogroup(df2.groupby("id")).applyInPandas( - cogroup_fn, schema="id long, cat string, label string" -).createOrReplaceTempView("flat_map_cogroups_view") \ No newline at end of file +_tests[test_name]() \ No newline at end of file From 9a623b84b5716b5d26732fce0aebac6a191b18b2 Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Wed, 18 Mar 2026 18:27:49 +0200 Subject: [PATCH 04/11] added missing DataFlintTestHelper.scala --- .../spark/dataflint/DataFlintTestHelper.scala | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintTestHelper.scala diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintTestHelper.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintTestHelper.scala new file mode 100644 index 0000000..58b50b0 --- /dev/null +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintTestHelper.scala @@ -0,0 +1,34 @@ +package org.apache.spark.dataflint + +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.metric.SQLMetric +import org.scalatest.Assertions + +case class MetricStats(total: Long, min: Long, med: Long, max: Long) + +trait DataFlintTestHelper extends Assertions { + + // With AQE, executedPlan is AdaptiveSparkPlanExec. After collect(), finalPhysicalPlan holds + // the fully optimised plan. For plan-structure tests that don't execute the query, use + // queryExecution.sparkPlan instead (our strategy runs before AQE wraps the plan). + def finalPlan(df: DataFrame): SparkPlan = df.queryExecution.executedPlan match { + case aqe: AdaptiveSparkPlanExec => aqe.finalPhysicalPlan + case p => p + } + + // Reads per-task total/min/med/max from the SQL status store for the given SQLMetric. + // createTimingMetric records each task's elapsed time individually; the store formats them as: + // "total (min, med, max (stageId: taskId))\nX ms (Y ms, Z ms, W ms (stage A.B: task C))" + def metricMinMax(metric: SQLMetric)(implicit spark: SparkSession): MetricStats = { + val sqlStore = spark.sharedState.statusStore + val execData = sqlStore.executionsList().maxBy(_.executionId) + val metricStr = sqlStore.executionMetrics(execData.executionId).getOrElse(metric.id, "") + val pattern = """(\d+) ms \((\d+) ms, (\d+) ms, (\d+) ms""".r + pattern.findFirstMatchIn(metricStr) match { + case Some(m) => MetricStats(m.group(1).toLong, m.group(2).toLong, m.group(3).toLong, m.group(4).toLong) + case None => fail(s"Expected per-task timing breakdown but got: '$metricStr'") + } + } +} \ No newline at end of file From 7e1593f8af5c2d3b62adc0bfd55483f7232845e1 Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Sun, 22 Mar 2026 11:49:54 +0200 Subject: [PATCH 05/11] added py tests --- .../dataflint_python_exec_integration_test.py | 85 ------------------- 1 file changed, 85 deletions(-) delete mode 100644 spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py diff --git a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py deleted file mode 100644 index 468e8b7..0000000 --- a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Integration test helper: registers one DataFlint Python exec node scenario as a temp view. - -Usage: pass the test name as the first argument (sys.argv[1]): - batch_eval — @udf → registers batch_eval_view - arrow_eval — @pandas_udf scalar → registers arrow_eval_view - flat_map_groups — applyInPandas → registers flat_map_groups_view - flat_map_cogroups— cogroup.applyIn… → registers flat_map_cogroups_view - -The Scala test (DataFlintPythonIntegrationSpec) calls PythonRunner.main with one of -the above names, then checks the corresponding view's executedPlan. -""" -import sys -import pyspark -import pyspark.java_gateway -from pyspark.sql import SparkSession -from pyspark.sql.functions import udf, pandas_udf -from pyspark.sql.types import LongType -import pandas as pd - -test_name = sys.argv[1] if len(sys.argv) > 1 else "" - -gateway = pyspark.java_gateway.launch_gateway() -static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession -jsc = static.javaSparkContext() -spark_jvm = static.session() - -conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf()) -sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf) -spark = SparkSession(sc, jsparkSession=spark_jvm) - -df = spark.createDataFrame( - [(1, "a"), (2, "b"), (3, "a"), (4, "b")], - ["id", "cat"] -) - - -def register_batch_eval(): - @udf(returnType=LongType()) - def double_udf(x): - return x * 2 - df.select(double_udf("id")).createOrReplaceTempView("batch_eval_view") - - -def register_arrow_eval(): - @pandas_udf(LongType()) - def double_pandas_udf(s: pd.Series) -> pd.Series: - return s * 2 - df.select(double_pandas_udf("id")).createOrReplaceTempView("arrow_eval_view") - - -def register_flat_map_groups(): - def identity_group(key, pdf): - return pdf - df.groupby("cat").applyInPandas( - identity_group, schema="id long, cat string" - ).createOrReplaceTempView("flat_map_groups_view") - - -def register_flat_map_cogroups(): - df2 = spark.createDataFrame( - [(1, "x"), (2, "y"), (3, "z"), (4, "w")], - ["id", "label"] - ) - def cogroup_fn(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: - left = left.copy() - left["label"] = right["label"].values[0] if len(right) > 0 else None - return left[["id", "cat", "label"]] - df.groupby("id").cogroup(df2.groupby("id")).applyInPandas( - cogroup_fn, schema="id long, cat string, label string" - ).createOrReplaceTempView("flat_map_cogroups_view") - - -_tests = { - "batch_eval": register_batch_eval, - "arrow_eval": register_arrow_eval, - "flat_map_groups": register_flat_map_groups, - "flat_map_cogroups": register_flat_map_cogroups, -} - -if test_name not in _tests: - print(f"Unknown test '{test_name}'. Valid options: {list(_tests)}", file=sys.stderr) - sys.exit(1) - -_tests[test_name]() \ No newline at end of file From 724b755d595b1a038025bfd7b09a7551c68ddf8d Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Sun, 22 Mar 2026 15:04:40 +0200 Subject: [PATCH 06/11] now installing venv as part of sbt test --- spark-plugin/build.sbt | 37 ++++++-- .../DataFlintPythonIntegrationSpec.scala | 17 +--- .../dataflint_python_exec_integration_test.py | 85 +++++++++++++++++++ 3 files changed, 119 insertions(+), 20 deletions(-) create mode 100644 spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py diff --git a/spark-plugin/build.sbt b/spark-plugin/build.sbt index 9261c0e..1847992 100644 --- a/spark-plugin/build.sbt +++ b/spark-plugin/build.sbt @@ -1,7 +1,12 @@ import xerial.sbt.Sonatype._ import sbtassembly.AssemblyPlugin.autoImport._ +import scala.sys.process._ lazy val versionNum: String = "0.8.8" +lazy val spark3Version: String = "3.5.1" +lazy val pythonVenvDir: String = System.getProperty("java.io.tmpdir") + "/dataflint-pyspark-venv" +lazy val pythonExec: String = "python3.11" +val createPythonVenv = taskKey[Unit]("Create Python venv and install pyspark test dependencies") lazy val scala212 = "2.12.20" lazy val scala213 = "2.13.16" lazy val supportedScalaVersions = List(scala212, scala213) @@ -35,8 +40,8 @@ lazy val plugin = (project in file("plugin")) } else { versionNum + "-SNAPSHOT" }), - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided", - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided", + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided", + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided", libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided", libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided", libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided", @@ -61,12 +66,12 @@ lazy val pluginspark3 = (project in file("pluginspark3")) } else { versionNum + "-SNAPSHOT" }), - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided", - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided", + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided", + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided", libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided", libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided", libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided", - + // Assembly configuration to create fat JAR with common code assembly / assemblyJarName := s"${name.value}_${scalaBinaryVersion.value}-${version.value}.jar", // Exclude Scala library from assembly - Spark provides its own Scala runtime @@ -96,8 +101,8 @@ lazy val pluginspark3 = (project in file("pluginspark3")) Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value, libraryDependencies += "org.scalatest" %% "scalatest-funsuite" % "3.2.17" % Test, libraryDependencies += "org.scalatest" %% "scalatest-shouldmatchers" % "3.2.17" % Test, - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % Test, - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % Test, + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % Test, + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % Test, // Include source and resources from plugin directory for tests Test / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala", @@ -107,6 +112,24 @@ lazy val pluginspark3 = (project in file("pluginspark3")) // Run test suites sequentially — parallel suites share the SparkSession via getOrCreate() // and one suite stopping the session causes NPEs in concurrently-running suites Test / parallelExecution := false, + createPythonVenv := { + val venvDir = new java.io.File(pythonVenvDir) + val log = streams.value.log + if (!venvDir.exists()) { + log.info(s"Creating Python venv at $pythonVenvDir ...") + val rc = Process(Seq(pythonExec, "-m", "venv", pythonVenvDir)).! + if (rc != 0) sys.error(s"Failed to create Python venv at $pythonVenvDir") + } + val pip = s"$pythonVenvDir/bin/pip" + log.info(s"Installing pyspark==$spark3Version pandas pyarrow into venv...") + val rc = Process(Seq(pip, "install", "--quiet", s"pyspark==$spark3Version", "pandas", "pyarrow")).! + if (rc != 0) sys.error("pip install failed") + }, + Test / compile := (Test / compile).dependsOn(createPythonVenv).value, + Test / testOnly := (Test / testOnly).dependsOn(createPythonVenv).evaluated, + Test / javaOptions ++= Seq( + s"-Ddataflint.projectRoot=${baseDirectory.value.getParentFile.toString}", + ), Test / javaOptions ++= { // --add-opens is not supported on Java 8 (spec version starts with "1.") if (sys.props("java.specification.version").startsWith("1.")) Seq.empty diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala index 49af685..5d8ca58 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala @@ -16,22 +16,13 @@ import java.nio.file.Paths * for the subprocess — no manual gateway setup needed. The Python script then connects * to this JVM via launch_gateway() and accesses the session through DataFlintStaticSession. * - * Requires: .venv with pyspark, pandas, pyarrow installed. - * python3 -m venv .venv && .venv/bin/pip install pyspark pandas pyarrow + * Python dependencies (pyspark, pandas, pyarrow) are installed automatically into + * .venv at test startup if not already present. */ class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper { - // pluginspark3 tests run with CWD = spark-plugin/pluginspark3/, so go up one level - // to reach the project root where .venv and pyspark-testing live. - private val projectRoot = Paths.get("").toAbsolutePath//.getParent - - private val venvPython: String = { - val p = projectRoot.resolve(Paths.get(".venv", "bin", "python3")) - require(p.toFile.exists(), - s"Python venv not found at $p\n" + - "Run: python3 -m venv .venv && .venv/bin/pip install pyspark pandas pyarrow") - p.toString - } + private val projectRoot = Paths.get(sys.props.getOrElse("dataflint.projectRoot", "")) + private val venvPython: String = System.getProperty("java.io.tmpdir") + "/dataflint-pyspark-venv/bin/python3" private val scriptPath: String = projectRoot.resolve( diff --git a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py new file mode 100644 index 0000000..468e8b7 --- /dev/null +++ b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py @@ -0,0 +1,85 @@ +""" +Integration test helper: registers one DataFlint Python exec node scenario as a temp view. + +Usage: pass the test name as the first argument (sys.argv[1]): + batch_eval — @udf → registers batch_eval_view + arrow_eval — @pandas_udf scalar → registers arrow_eval_view + flat_map_groups — applyInPandas → registers flat_map_groups_view + flat_map_cogroups— cogroup.applyIn… → registers flat_map_cogroups_view + +The Scala test (DataFlintPythonIntegrationSpec) calls PythonRunner.main with one of +the above names, then checks the corresponding view's executedPlan. +""" +import sys +import pyspark +import pyspark.java_gateway +from pyspark.sql import SparkSession +from pyspark.sql.functions import udf, pandas_udf +from pyspark.sql.types import LongType +import pandas as pd + +test_name = sys.argv[1] if len(sys.argv) > 1 else "" + +gateway = pyspark.java_gateway.launch_gateway() +static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession +jsc = static.javaSparkContext() +spark_jvm = static.session() + +conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf()) +sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf) +spark = SparkSession(sc, jsparkSession=spark_jvm) + +df = spark.createDataFrame( + [(1, "a"), (2, "b"), (3, "a"), (4, "b")], + ["id", "cat"] +) + + +def register_batch_eval(): + @udf(returnType=LongType()) + def double_udf(x): + return x * 2 + df.select(double_udf("id")).createOrReplaceTempView("batch_eval_view") + + +def register_arrow_eval(): + @pandas_udf(LongType()) + def double_pandas_udf(s: pd.Series) -> pd.Series: + return s * 2 + df.select(double_pandas_udf("id")).createOrReplaceTempView("arrow_eval_view") + + +def register_flat_map_groups(): + def identity_group(key, pdf): + return pdf + df.groupby("cat").applyInPandas( + identity_group, schema="id long, cat string" + ).createOrReplaceTempView("flat_map_groups_view") + + +def register_flat_map_cogroups(): + df2 = spark.createDataFrame( + [(1, "x"), (2, "y"), (3, "z"), (4, "w")], + ["id", "label"] + ) + def cogroup_fn(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + left = left.copy() + left["label"] = right["label"].values[0] if len(right) > 0 else None + return left[["id", "cat", "label"]] + df.groupby("id").cogroup(df2.groupby("id")).applyInPandas( + cogroup_fn, schema="id long, cat string, label string" + ).createOrReplaceTempView("flat_map_cogroups_view") + + +_tests = { + "batch_eval": register_batch_eval, + "arrow_eval": register_arrow_eval, + "flat_map_groups": register_flat_map_groups, + "flat_map_cogroups": register_flat_map_cogroups, +} + +if test_name not in _tests: + print(f"Unknown test '{test_name}'. Valid options: {list(_tests)}", file=sys.stderr) + sys.exit(1) + +_tests[test_name]() \ No newline at end of file From 56b0e63dc635a10bfa0ff3632b8ed0ce2489353a Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Sun, 22 Mar 2026 15:42:26 +0200 Subject: [PATCH 07/11] updated ci to install p3.11 --- .github/workflows/ci.yml | 5 +++++ spark-plugin/build.sbt | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index df1049e..756f46e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,6 +59,11 @@ jobs: run: npm run test working-directory: ./spark-ui + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Build and test plugin run: sbt +test working-directory: ./spark-plugin diff --git a/spark-plugin/build.sbt b/spark-plugin/build.sbt index 1847992..d12c170 100644 --- a/spark-plugin/build.sbt +++ b/spark-plugin/build.sbt @@ -5,7 +5,7 @@ import scala.sys.process._ lazy val versionNum: String = "0.8.8" lazy val spark3Version: String = "3.5.1" lazy val pythonVenvDir: String = System.getProperty("java.io.tmpdir") + "/dataflint-pyspark-venv" -lazy val pythonExec: String = "python3.11" +lazy val pythonExec: String = sys.env.getOrElse("DATAFLINT_PYTHON_EXEC", "python3.11") val createPythonVenv = taskKey[Unit]("Create Python venv and install pyspark test dependencies") lazy val scala212 = "2.12.20" lazy val scala213 = "2.13.16" From bbabb83734924896532ceff77df341fa6c7021c1 Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Sun, 22 Mar 2026 17:54:06 +0200 Subject: [PATCH 08/11] pyspark testing infra from scala --- .../dataflint/DataFlintStaticSession.scala | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintStaticSession.scala diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintStaticSession.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintStaticSession.scala new file mode 100644 index 0000000..12d57a4 --- /dev/null +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintStaticSession.scala @@ -0,0 +1,29 @@ +package org.apache.spark.dataflint + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.SparkSession + +/** + * Holds the active SparkSession so Python integration tests can access it via Py4J. + * + * Usage in Scala test: + * DataFlintStaticSession.set(spark) + * PythonRunner.main(Array(scriptPath, "")) + * + * Usage in Python script (after PythonRunner sets PYSPARK_GATEWAY_PORT): + * gateway = pyspark.java_gateway.launch_gateway() + * static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession + * sc = pyspark.SparkContext(gateway=gateway, jsc=static.javaSparkContext()) + * spark = SparkSession(sc, jsparkSession=static.session()) + */ +object DataFlintStaticSession { + @volatile private var _session: SparkSession = _ + + def set(session: SparkSession): Unit = { _session = session } + def clear(): Unit = { _session = null } + + // Scala 2 generates static forwarders — accessible as: + // gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession.session() + def session: SparkSession = _session + def javaSparkContext: JavaSparkContext = new JavaSparkContext(_session.sparkContext) +} \ No newline at end of file From 944410a608175bc6f11518dd98394356be674d40 Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Mon, 23 Mar 2026 15:05:46 +0200 Subject: [PATCH 09/11] removed readln --- spark-plugin/pyspark-testing/dataflint_pyspark_example.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/spark-plugin/pyspark-testing/dataflint_pyspark_example.py b/spark-plugin/pyspark-testing/dataflint_pyspark_example.py index ccccfde..a7894ce 100755 --- a/spark-plugin/pyspark-testing/dataflint_pyspark_example.py +++ b/spark-plugin/pyspark-testing/dataflint_pyspark_example.py @@ -370,7 +370,4 @@ def apply_category_discount(left_pdf, right_pdf): print("\n" + "="*80) print("Done!") -print("="*80) - -input("\nPress Enter to exit...") -spark.stop() \ No newline at end of file +print("="*80) \ No newline at end of file From 6491fb2fcf465b7ec5975f3baf8026d30e58a297 Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Mon, 23 Mar 2026 16:10:47 +0200 Subject: [PATCH 10/11] added readln --- spark-plugin/pyspark-testing/dataflint_pyspark_example.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spark-plugin/pyspark-testing/dataflint_pyspark_example.py b/spark-plugin/pyspark-testing/dataflint_pyspark_example.py index a7894ce..ccccfde 100755 --- a/spark-plugin/pyspark-testing/dataflint_pyspark_example.py +++ b/spark-plugin/pyspark-testing/dataflint_pyspark_example.py @@ -370,4 +370,7 @@ def apply_category_discount(left_pdf, right_pdf): print("\n" + "="*80) print("Done!") -print("="*80) \ No newline at end of file +print("="*80) + +input("\nPress Enter to exit...") +spark.stop() \ No newline at end of file From 69629056340316fa542d475e5569712e24788481 Mon Sep 17 00:00:00 2001 From: Avi Minsky Date: Mon, 23 Mar 2026 16:56:20 +0200 Subject: [PATCH 11/11] updated sbt with venv creation on test --- spark-plugin/build.sbt | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/spark-plugin/build.sbt b/spark-plugin/build.sbt index 9e85afe..54f37ad 100644 --- a/spark-plugin/build.sbt +++ b/spark-plugin/build.sbt @@ -40,8 +40,8 @@ lazy val plugin = (project in file("plugin")) } else { versionNum + "-SNAPSHOT" }), - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided", - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided", + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided", + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided", libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided", libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided", libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided", @@ -66,12 +66,12 @@ lazy val pluginspark3 = (project in file("pluginspark3")) } else { versionNum + "-SNAPSHOT" }), - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided", - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided", + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided", + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided", libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided", libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided", libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided", - + // Assembly configuration to create fat JAR with common code assembly / assemblyJarName := s"${name.value}_${scalaBinaryVersion.value}-${version.value}.jar", // Exclude Scala library from assembly - Spark provides its own Scala runtime @@ -101,8 +101,8 @@ lazy val pluginspark3 = (project in file("pluginspark3")) Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value, libraryDependencies += "org.scalatest" %% "scalatest-funsuite" % "3.2.17" % Test, libraryDependencies += "org.scalatest" %% "scalatest-shouldmatchers" % "3.2.17" % Test, - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % Test, - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % Test, + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % Test, + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % Test, // Include source and resources from plugin directory for tests Test / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala", @@ -112,6 +112,24 @@ lazy val pluginspark3 = (project in file("pluginspark3")) // Run test suites sequentially — parallel suites share the SparkSession via getOrCreate() // and one suite stopping the session causes NPEs in concurrently-running suites Test / parallelExecution := false, + createPythonVenv := { + val venvDir = new java.io.File(pythonVenvDir) + val log = streams.value.log + if (!venvDir.exists()) { + log.info(s"Creating Python venv at $pythonVenvDir ...") + val rc = Process(Seq(pythonExec, "-m", "venv", pythonVenvDir)).! + if (rc != 0) sys.error(s"Failed to create Python venv at $pythonVenvDir") + } + val pip = s"$pythonVenvDir/bin/pip" + log.info(s"Installing pyspark==$spark3Version pandas pyarrow into venv...") + val rc = Process(Seq(pip, "install", "--quiet", s"pyspark==$spark3Version", "pandas", "pyarrow")).! + if (rc != 0) sys.error("pip install failed") + }, + Test / compile := (Test / compile).dependsOn(createPythonVenv).value, + Test / testOnly := (Test / testOnly).dependsOn(createPythonVenv).evaluated, + Test / javaOptions ++= Seq( + s"-Ddataflint.projectRoot=${baseDirectory.value.getParentFile.toString}", + ), Test / javaOptions ++= { // --add-opens is not supported on Java 8 (spec version starts with "1.") if (sys.props("java.specification.version").startsWith("1.")) Seq.empty