diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index df1049e..756f46e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,6 +59,11 @@ jobs: run: npm run test working-directory: ./spark-ui + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Build and test plugin run: sbt +test working-directory: ./spark-plugin diff --git a/spark-plugin/build.sbt b/spark-plugin/build.sbt index 659dfca..54f37ad 100644 --- a/spark-plugin/build.sbt +++ b/spark-plugin/build.sbt @@ -1,7 +1,12 @@ import xerial.sbt.Sonatype._ import sbtassembly.AssemblyPlugin.autoImport._ +import scala.sys.process._ lazy val versionNum: String = "0.8.9" +lazy val spark3Version: String = "3.5.1" +lazy val pythonVenvDir: String = System.getProperty("java.io.tmpdir") + "/dataflint-pyspark-venv" +lazy val pythonExec: String = sys.env.getOrElse("DATAFLINT_PYTHON_EXEC", "python3.11") +val createPythonVenv = taskKey[Unit]("Create Python venv and install pyspark test dependencies") lazy val scala212 = "2.12.20" lazy val scala213 = "2.13.16" lazy val supportedScalaVersions = List(scala212, scala213) @@ -35,8 +40,8 @@ lazy val plugin = (project in file("plugin")) } else { versionNum + "-SNAPSHOT" }), - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided", - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided", + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided", + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided", libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided", libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided", libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided", @@ -61,12 +66,12 @@ lazy val pluginspark3 = (project in file("pluginspark3")) } else { versionNum + "-SNAPSHOT" }), - 
libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided", - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided", + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided", + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided", libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided", libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided", libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided", - + // Assembly configuration to create fat JAR with common code assembly / assemblyJarName := s"${name.value}_${scalaBinaryVersion.value}-${version.value}.jar", // Exclude Scala library from assembly - Spark provides its own Scala runtime @@ -96,8 +101,8 @@ lazy val pluginspark3 = (project in file("pluginspark3")) Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value, libraryDependencies += "org.scalatest" %% "scalatest-funsuite" % "3.2.17" % Test, libraryDependencies += "org.scalatest" %% "scalatest-shouldmatchers" % "3.2.17" % Test, - libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % Test, - libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % Test, + libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % Test, + libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % Test, // Include source and resources from plugin directory for tests Test / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala", @@ -107,6 +112,24 @@ lazy val pluginspark3 = (project in file("pluginspark3")) // Run test suites sequentially — parallel suites share the SparkSession via getOrCreate() // and one suite stopping the session causes NPEs in concurrently-running suites Test / parallelExecution := false, + createPythonVenv := { + val venvDir = new 
java.io.File(pythonVenvDir) + val log = streams.value.log + if (!venvDir.exists()) { + log.info(s"Creating Python venv at $pythonVenvDir ...") + val rc = Process(Seq(pythonExec, "-m", "venv", pythonVenvDir)).! + if (rc != 0) sys.error(s"Failed to create Python venv at $pythonVenvDir") + } + val pip = s"$pythonVenvDir/bin/pip" + log.info(s"Installing pyspark==$spark3Version pandas pyarrow into venv...") + val rc = Process(Seq(pip, "install", "--quiet", s"pyspark==$spark3Version", "pandas", "pyarrow")).! + if (rc != 0) sys.error("pip install failed") + }, + Test / compile := (Test / compile).dependsOn(createPythonVenv).value, + Test / testOnly := (Test / testOnly).dependsOn(createPythonVenv).evaluated, + Test / javaOptions ++= Seq( + s"-Ddataflint.projectRoot=${baseDirectory.value.getParentFile.toString}", + ), Test / javaOptions ++= { // --add-opens is not supported on Java 8 (spec version starts with "1.") if (sys.props("java.specification.version").startsWith("1.")) Seq.empty diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala index 7f3e5f9..f3e3fba 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonExecSpec.scala @@ -10,7 +10,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers -class DataFlintPythonExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll { +class DataFlintPythonExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper { private var spark: SparkSession = _ diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala 
b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala new file mode 100644 index 0000000..5d8ca58 --- /dev/null +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintPythonIntegrationSpec.scala @@ -0,0 +1,89 @@ +package org.apache.spark.dataflint + +import org.apache.spark.deploy.PythonRunner +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.SparkPlan +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import java.nio.file.Paths + +/** + * Integration test that runs a Python script against the instrumented SparkSession. + * + * PythonRunner.main() creates a Py4JServer internally and sets PYSPARK_GATEWAY_PORT + * for the subprocess — no manual gateway setup needed. The Python script then connects + * to this JVM via launch_gateway() and accesses the session through DataFlintStaticSession. + * + * Python dependencies (pyspark, pandas, pyarrow) are installed automatically into + * the dataflint-pyspark-venv venv under the system temp dir if not already present. + */ +class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper { + + private val projectRoot = Paths.get(sys.props.getOrElse("dataflint.projectRoot", "")) + private val venvPython: String = System.getProperty("java.io.tmpdir") + "/dataflint-pyspark-venv/bin/python3" + + private val scriptPath: String = + projectRoot.resolve( + Paths.get("pyspark-testing", "dataflint_python_exec_integration_test.py")).toString + + private var spark: SparkSession = _ + + override def beforeAll(): Unit = { + // Set before SparkSession creation so the conf is also applied to UDF workers. + // Also sets it as a system property so PythonRunner's internal `new SparkConf()` + // picks it up when resolving the Python executable for the subprocess.
+ System.setProperty("spark.pyspark.python", venvPython) + + spark = SparkSession.builder() + .master("local[2]") + .appName("DataFlintPythonIntegrationSpec") + .config("spark.pyspark.python", venvPython) + .config(DataflintSparkUICommonLoader.INSTRUMENT_SPARK_ENABLED, "true") + .config(DataflintSparkUICommonLoader.INSTRUMENT_ARROW_EVAL_PYTHON_ENABLED, "true") + .config(DataflintSparkUICommonLoader.INSTRUMENT_BATCH_EVAL_PYTHON_ENABLED, "true") + .config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_GROUPS_PANDAS_ENABLED, "true") + .config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_COGROUPS_PANDAS_ENABLED, "true") + .config("spark.ui.enabled", "false") + .withExtensions(new DataFlintInstrumentationExtension) + .getOrCreate() + + DataFlintStaticSession.set(spark) + } + + override def afterAll(): Unit = { + DataFlintStaticSession.clear() + System.clearProperty("spark.pyspark.python") + if (spark != null) spark.stop() + } + + /** Run the Python script for the given test name, then check the registered view. 
*/ + private def assertPythonNode(testName: String, view: String, expectedNode: String): Unit = { + PythonRunner.main(Array(scriptPath, "", testName)) + val df = spark.table(view) + df.collect() + val node = finalPlan(df).collect { + case n: SparkPlan if n.getClass.getSimpleName.contains(expectedNode) => n + }.headOption + node shouldBe defined + node.get.metrics should contain key "duration" + node.get.metrics("duration").value should be > 0L + } + + test("BatchEvalPython (@udf) is instrumented") { + assertPythonNode("batch_eval", "batch_eval_view", "DataFlintBatchEvalPython") + } + + test("ArrowEvalPython (@pandas_udf scalar) is instrumented") { + assertPythonNode("arrow_eval", "arrow_eval_view", "DataFlintArrowEvalPython") + } + + test("FlatMapGroupsInPandas (applyInPandas) is instrumented") { + assertPythonNode("flat_map_groups", "flat_map_groups_view", "DataFlintFlatMapGroupsInPandas") + } + + test("FlatMapCoGroupsInPandas (cogroup applyInPandas) is instrumented") { + assertPythonNode("flat_map_cogroups", "flat_map_cogroups_view", "DataFlintFlatMapCoGroupsInPandas") + } +} \ No newline at end of file diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintStaticSession.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintStaticSession.scala new file mode 100644 index 0000000..12d57a4 --- /dev/null +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintStaticSession.scala @@ -0,0 +1,29 @@ +package org.apache.spark.dataflint + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.SparkSession + +/** + * Holds the active SparkSession so Python integration tests can access it via Py4J. 
+ * + * Usage in Scala test: + * DataFlintStaticSession.set(spark) + * PythonRunner.main(Array(scriptPath, "")) + * + * Usage in Python script (after PythonRunner sets PYSPARK_GATEWAY_PORT): + * gateway = pyspark.java_gateway.launch_gateway() + * static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession + * sc = pyspark.SparkContext(gateway=gateway, jsc=static.javaSparkContext()) + * spark = SparkSession(sc, jsparkSession=static.session()) + */ +object DataFlintStaticSession { + @volatile private var _session: SparkSession = _ + + def set(session: SparkSession): Unit = { _session = session } + def clear(): Unit = { _session = null } + + // Scala 2 generates static forwarders — accessible as: + // gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession.session() + def session: SparkSession = _session + def javaSparkContext: JavaSparkContext = new JavaSparkContext(_session.sparkContext) +} \ No newline at end of file diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/SqlMetricTestHelper.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintTestHelper.scala similarity index 61% rename from spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/SqlMetricTestHelper.scala rename to spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintTestHelper.scala index 88e6578..58b50b0 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/SqlMetricTestHelper.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintTestHelper.scala @@ -1,12 +1,22 @@ package org.apache.spark.dataflint -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.metric.SQLMetric import org.scalatest.Assertions case class MetricStats(total: 
Long, min: Long, med: Long, max: Long) -trait SqlMetricTestHelper extends Assertions { +trait DataFlintTestHelper extends Assertions { + + // With AQE, executedPlan is AdaptiveSparkPlanExec. After collect(), finalPhysicalPlan holds + // the fully optimised plan. For plan-structure tests that don't execute the query, use + // queryExecution.sparkPlan instead (our strategy runs before AQE wraps the plan). + def finalPlan(df: DataFrame): SparkPlan = df.queryExecution.executedPlan match { + case aqe: AdaptiveSparkPlanExec => aqe.finalPhysicalPlan + case p => p + } // Reads per-task total/min/med/max from the SQL status store for the given SQLMetric. // createTimingMetric records each task's elapsed time individually; the store formats them as: diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala index ded8182..41fb52d 100644 --- a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala +++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWindowExecSpec.scala @@ -2,7 +2,6 @@ package org.apache.spark.dataflint import org.apache.spark.sql.execution.{ExplicitRepartitionExtension, ExplicitRepartitionOps} import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.window.DataFlintWindowExec import org.apache.spark.sql.expressions.{Aggregator, Window} import org.apache.spark.sql.functions.{col, rank, udaf} @@ -28,7 +27,7 @@ private class SlowSumAggregator(fromSleep: Long, toSleep: Long) extends Aggregat def outputEncoder: Encoder[Long] = Encoders.scalaLong } -class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with SqlMetricTestHelper { +class DataFlintWindowExecSpec extends AnyFunSuite with 
Matchers with BeforeAndAfterAll with DataFlintTestHelper { private var spark: SparkSession = _ @@ -52,13 +51,6 @@ class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAf if (spark != null) spark.stop() } - // With AQE, executedPlan is AdaptiveSparkPlanExec. After collect(), finalPhysicalPlan holds - // the fully optimised plan. For plan-structure tests that don't execute the query, use - // queryExecution.sparkPlan instead (our strategy runs before AQE wraps the plan). - private def finalPlan(df: DataFrame) = df.queryExecution.executedPlan match { - case aqe: AdaptiveSparkPlanExec => aqe.finalPhysicalPlan - case p => p - } test("DataFlintWindowPlannerStrategy replaces WindowExec with DataFlintWindowExec for SQL window") { val session = spark diff --git a/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py new file mode 100644 index 0000000..468e8b7 --- /dev/null +++ b/spark-plugin/pyspark-testing/dataflint_python_exec_integration_test.py @@ -0,0 +1,85 @@ +""" +Integration test helper: registers one DataFlint Python exec node scenario as a temp view. + +Usage: pass the test name as the first argument (sys.argv[1]): + batch_eval — @udf → registers batch_eval_view + arrow_eval — @pandas_udf scalar → registers arrow_eval_view + flat_map_groups — applyInPandas → registers flat_map_groups_view + flat_map_cogroups— cogroup.applyIn… → registers flat_map_cogroups_view + +The Scala test (DataFlintPythonIntegrationSpec) calls PythonRunner.main with one of +the above names, then checks the corresponding view's executedPlan. 
+""" +import sys +import pyspark +import pyspark.java_gateway +from pyspark.sql import SparkSession +from pyspark.sql.functions import udf, pandas_udf +from pyspark.sql.types import LongType +import pandas as pd + +test_name = sys.argv[1] if len(sys.argv) > 1 else "" + +gateway = pyspark.java_gateway.launch_gateway() +static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession +jsc = static.javaSparkContext() +spark_jvm = static.session() + +conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf()) +sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf) +spark = SparkSession(sc, jsparkSession=spark_jvm) + +df = spark.createDataFrame( + [(1, "a"), (2, "b"), (3, "a"), (4, "b")], + ["id", "cat"] +) + + +def register_batch_eval(): + @udf(returnType=LongType()) + def double_udf(x): + return x * 2 + df.select(double_udf("id")).createOrReplaceTempView("batch_eval_view") + + +def register_arrow_eval(): + @pandas_udf(LongType()) + def double_pandas_udf(s: pd.Series) -> pd.Series: + return s * 2 + df.select(double_pandas_udf("id")).createOrReplaceTempView("arrow_eval_view") + + +def register_flat_map_groups(): + def identity_group(key, pdf): + return pdf + df.groupby("cat").applyInPandas( + identity_group, schema="id long, cat string" + ).createOrReplaceTempView("flat_map_groups_view") + + +def register_flat_map_cogroups(): + df2 = spark.createDataFrame( + [(1, "x"), (2, "y"), (3, "z"), (4, "w")], + ["id", "label"] + ) + def cogroup_fn(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + left = left.copy() + left["label"] = right["label"].values[0] if len(right) > 0 else None + return left[["id", "cat", "label"]] + df.groupby("id").cogroup(df2.groupby("id")).applyInPandas( + cogroup_fn, schema="id long, cat string, label string" + ).createOrReplaceTempView("flat_map_cogroups_view") + + +_tests = { + "batch_eval": register_batch_eval, + "arrow_eval": register_arrow_eval, + "flat_map_groups": register_flat_map_groups, + "flat_map_cogroups": 
register_flat_map_cogroups, +} + +if test_name not in _tests: + print(f"Unknown test '{test_name}'. Valid options: {list(_tests)}", file=sys.stderr) + sys.exit(1) + +_tests[test_name]() \ No newline at end of file