Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ jobs:
run: npm run test
working-directory: ./spark-ui

- name: Set up Python 3.11
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Build and test plugin
run: sbt +test
working-directory: ./spark-plugin
Expand Down
37 changes: 30 additions & 7 deletions spark-plugin/build.sbt
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import xerial.sbt.Sonatype._
import sbtassembly.AssemblyPlugin.autoImport._
import scala.sys.process._

lazy val versionNum: String = "0.8.9"
lazy val spark3Version: String = "3.5.1"
lazy val pythonVenvDir: String = System.getProperty("java.io.tmpdir") + "/dataflint-pyspark-venv"
lazy val pythonExec: String = sys.env.getOrElse("DATAFLINT_PYTHON_EXEC", "python3.11")
val createPythonVenv = taskKey[Unit]("Create Python venv and install pyspark test dependencies")
lazy val scala212 = "2.12.20"
lazy val scala213 = "2.13.16"
lazy val supportedScalaVersions = List(scala212, scala213)
Expand Down Expand Up @@ -35,8 +40,8 @@ lazy val plugin = (project in file("plugin"))
} else {
versionNum + "-SNAPSHOT"
}),
libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided",
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided",
libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided",
libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided",
libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided",
libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided",
libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided",
Expand All @@ -61,12 +66,12 @@ lazy val pluginspark3 = (project in file("pluginspark3"))
} else {
versionNum + "-SNAPSHOT"
}),
libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % "provided",
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided",
libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % "provided",
libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % "provided",
libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.12.470" % "provided",
libraryDependencies += "org.apache.iceberg" %% "iceberg-spark-runtime-3.5" % "1.5.0" % "provided",
libraryDependencies += "io.delta" %% "delta-spark" % "3.2.0" % "provided",

// Assembly configuration to create fat JAR with common code
assembly / assemblyJarName := s"${name.value}_${scalaBinaryVersion.value}-${version.value}.jar",
// Exclude Scala library from assembly - Spark provides its own Scala runtime
Expand Down Expand Up @@ -96,8 +101,8 @@ lazy val pluginspark3 = (project in file("pluginspark3"))
Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value,
libraryDependencies += "org.scalatest" %% "scalatest-funsuite" % "3.2.17" % Test,
libraryDependencies += "org.scalatest" %% "scalatest-shouldmatchers" % "3.2.17" % Test,
libraryDependencies += "org.apache.spark" %% "spark-core" % "3.5.1" % Test,
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % Test,
libraryDependencies += "org.apache.spark" %% "spark-core" % spark3Version % Test,
libraryDependencies += "org.apache.spark" %% "spark-sql" % spark3Version % Test,

// Include source and resources from plugin directory for tests
Test / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala",
Expand All @@ -107,6 +112,24 @@ lazy val pluginspark3 = (project in file("pluginspark3"))
// Run test suites sequentially — parallel suites share the SparkSession via getOrCreate()
// and one suite stopping the session causes NPEs in concurrently-running suites
Test / parallelExecution := false,
createPythonVenv := {
// Provisions a dedicated Python venv (under java.io.tmpdir) holding the pyspark
// release matching spark3Version, so the Python integration tests exercise the
// same Spark version as the JVM side. Re-running is cheap: venv creation is
// skipped when the directory exists, and pip no-ops on already-satisfied deps.
val venvDir = new java.io.File(pythonVenvDir)
val log = streams.value.log
if (!venvDir.exists()) {
log.info(s"Creating Python venv at $pythonVenvDir ...")
val rc = Process(Seq(pythonExec, "-m", "venv", pythonVenvDir)).!
if (rc != 0) sys.error(s"Failed to create Python venv at $pythonVenvDir")
}
// Run pip via the venv's interpreter ("python -m pip") instead of the bin/pip
// shim: pip's own docs recommend this form — it is robust to stale or missing
// shims and guarantees the install targets exactly this venv's site-packages.
val venvPython = s"$pythonVenvDir/bin/python"
log.info(s"Installing pyspark==$spark3Version pandas pyarrow into venv...")
val rc = Process(Seq(venvPython, "-m", "pip", "install", "--quiet", s"pyspark==$spark3Version", "pandas", "pyarrow")).!
if (rc != 0) sys.error("pip install failed")
},
Test / compile := (Test / compile).dependsOn(createPythonVenv).value,
Test / testOnly := (Test / testOnly).dependsOn(createPythonVenv).evaluated,
Test / javaOptions ++= Seq(
s"-Ddataflint.projectRoot=${baseDirectory.value.getParentFile.toString}",
),
Test / javaOptions ++= {
// --add-opens is not supported on Java 8 (spec version starts with "1.")
if (sys.props("java.specification.version").startsWith("1.")) Seq.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class DataFlintPythonExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll {
class DataFlintPythonExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper {

private var spark: SparkSession = _

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package org.apache.spark.dataflint

import org.apache.spark.deploy.PythonRunner
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

import java.nio.file.Paths

/**
* Integration test that runs a Python script against the instrumented SparkSession.
*
* PythonRunner.main() creates a Py4JServer internally and sets PYSPARK_GATEWAY_PORT
* for the subprocess — no manual gateway setup needed. The Python script then connects
* to this JVM via launch_gateway() and accesses the session through DataFlintStaticSession.
*
 * Python dependencies (pyspark, pandas, pyarrow) are installed automatically by the
 * createPythonVenv sbt task into a venv under the system temp dir before tests run.
*/
class DataFlintPythonIntegrationSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper {

// Plugin project root, injected via Test/javaOptions (-Ddataflint.projectRoot) in build.sbt.
private val projectRoot = Paths.get(sys.props.getOrElse("dataflint.projectRoot", ""))
// Interpreter inside the venv created by the createPythonVenv sbt task.
// NOTE(review): this path must stay in sync with pythonVenvDir in build.sbt — confirm.
private val venvPython: String = System.getProperty("java.io.tmpdir") + "/dataflint-pyspark-venv/bin/python3"

// Absolute path of the Python helper that registers each scenario as a temp view.
private val scriptPath: String =
projectRoot.resolve(
Paths.get("pyspark-testing", "dataflint_python_exec_integration_test.py")).toString

private var spark: SparkSession = _

override def beforeAll(): Unit = {
// Set before SparkSession creation so the conf is also applied to UDF workers.
// Also sets it as a system property so PythonRunner's internal `new SparkConf()`
// picks it up when resolving the Python executable for the subprocess.
System.setProperty("spark.pyspark.python", venvPython)

// Enable every Python-exec instrumentation flag so each scenario below produces
// its DataFlint* physical node. The UI is disabled — only the plan is inspected.
spark = SparkSession.builder()
.master("local[2]")
.appName("DataFlintPythonIntegrationSpec")
.config("spark.pyspark.python", venvPython)
.config(DataflintSparkUICommonLoader.INSTRUMENT_SPARK_ENABLED, "true")
.config(DataflintSparkUICommonLoader.INSTRUMENT_ARROW_EVAL_PYTHON_ENABLED, "true")
.config(DataflintSparkUICommonLoader.INSTRUMENT_BATCH_EVAL_PYTHON_ENABLED, "true")
.config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_GROUPS_PANDAS_ENABLED, "true")
.config(DataflintSparkUICommonLoader.INSTRUMENT_FLAT_MAP_COGROUPS_PANDAS_ENABLED, "true")
.config("spark.ui.enabled", "false")
.withExtensions(new DataFlintInstrumentationExtension)
.getOrCreate()

// Publish the session so the Python script can reach it through Py4J.
DataFlintStaticSession.set(spark)
}

override def afterAll(): Unit = {
// Tear down in reverse order of setup so no stale references survive the suite.
DataFlintStaticSession.clear()
System.clearProperty("spark.pyspark.python")
if (spark != null) spark.stop()
}

/** Run the Python script for the given test name, then check the registered view.
 *
 * @param testName     scenario selector passed to the script as sys.argv[1]
 * @param view         temp view name the script is expected to register
 * @param expectedNode simple-class-name fragment of the DataFlint physical node
 */
private def assertPythonNode(testName: String, view: String, expectedNode: String): Unit = {
// PythonRunner args: (pythonFile, pyFiles, scriptArgs...) — presumably the empty
// second element is the pyFiles list; TODO confirm against Spark's PythonRunner.
PythonRunner.main(Array(scriptPath, "", testName))
val df = spark.table(view)
// Execute first: the duration metric is only populated after tasks have run.
df.collect()
val node = finalPlan(df).collect {
case n: SparkPlan if n.getClass.getSimpleName.contains(expectedNode) => n
}.headOption
node shouldBe defined
node.get.metrics should contain key "duration"
node.get.metrics("duration").value should be > 0L
}

test("BatchEvalPython (@udf) is instrumented") {
assertPythonNode("batch_eval", "batch_eval_view", "DataFlintBatchEvalPython")
}

test("ArrowEvalPython (@pandas_udf scalar) is instrumented") {
assertPythonNode("arrow_eval", "arrow_eval_view", "DataFlintArrowEvalPython")
}

test("FlatMapGroupsInPandas (applyInPandas) is instrumented") {
assertPythonNode("flat_map_groups", "flat_map_groups_view", "DataFlintFlatMapGroupsInPandas")
}

test("FlatMapCoGroupsInPandas (cogroup applyInPandas) is instrumented") {
assertPythonNode("flat_map_cogroups", "flat_map_cogroups_view", "DataFlintFlatMapCoGroupsInPandas")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.apache.spark.dataflint

import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession

/**
* Holds the active SparkSession so Python integration tests can access it via Py4J.
*
* Usage in Scala test:
* DataFlintStaticSession.set(spark)
* PythonRunner.main(Array(scriptPath, ""))
*
* Usage in Python script (after PythonRunner sets PYSPARK_GATEWAY_PORT):
* gateway = pyspark.java_gateway.launch_gateway()
* static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession
* sc = pyspark.SparkContext(gateway=gateway, jsc=static.javaSparkContext())
* spark = SparkSession(sc, jsparkSession=static.session())
*/
object DataFlintStaticSession {
// Volatile: set/cleared by the JVM test thread, read from Py4J callback threads.
@volatile private var _session: SparkSession = _

/** Publishes the session for Python-side access; call before PythonRunner.main. */
def set(session: SparkSession): Unit = { _session = session }
/** Drops the published session; call in afterAll to avoid stale references. */
def clear(): Unit = { _session = null }

// Scala 2 generates static forwarders — accessible as:
// gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession.session()
// May be null when set() has not been called — Python callers check/fail there.
def session: SparkSession = _session

/** Wraps the held session's SparkContext for pyspark's SparkContext(jsc=...). */
def javaSparkContext: JavaSparkContext = {
// Read the volatile once so the null-check and the use see the same value,
// and fail fast with a clear message instead of an opaque NPE over Py4J.
val s = _session
if (s == null) sys.error("DataFlintStaticSession.set(spark) must be called before javaSparkContext")
new JavaSparkContext(s.sparkContext)
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
package org.apache.spark.dataflint

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.metric.SQLMetric
import org.scalatest.Assertions

case class MetricStats(total: Long, min: Long, med: Long, max: Long)

trait SqlMetricTestHelper extends Assertions {
trait DataFlintTestHelper extends Assertions {

// With AQE enabled, executedPlan is an AdaptiveSparkPlanExec wrapper. After the
// query has run (e.g. collect()), finalPhysicalPlan holds the fully re-optimised
// plan, which is what metric/node assertions must inspect. For plan-structure
// tests that never execute the query, use queryExecution.sparkPlan instead (our
// strategy runs before AQE wraps the plan). Non-AQE plans are returned unchanged.
def finalPlan(df: DataFrame): SparkPlan = df.queryExecution.executedPlan match {
case aqe: AdaptiveSparkPlanExec => aqe.finalPhysicalPlan
case p => p
}

// Reads per-task total/min/med/max from the SQL status store for the given SQLMetric.
// createTimingMetric records each task's elapsed time individually; the store formats them as:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.apache.spark.dataflint

import org.apache.spark.sql.execution.{ExplicitRepartitionExtension, ExplicitRepartitionOps}
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.window.DataFlintWindowExec
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions.{col, rank, udaf}
Expand All @@ -28,7 +27,7 @@ private class SlowSumAggregator(fromSleep: Long, toSleep: Long) extends Aggregat
def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with SqlMetricTestHelper {
class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll with DataFlintTestHelper {

private var spark: SparkSession = _

Expand All @@ -52,13 +51,6 @@ class DataFlintWindowExecSpec extends AnyFunSuite with Matchers with BeforeAndAf
if (spark != null) spark.stop()
}

// With AQE, executedPlan is AdaptiveSparkPlanExec. After collect(), finalPhysicalPlan holds
// the fully optimised plan. For plan-structure tests that don't execute the query, use
// queryExecution.sparkPlan instead (our strategy runs before AQE wraps the plan).
private def finalPlan(df: DataFrame) = df.queryExecution.executedPlan match {
case aqe: AdaptiveSparkPlanExec => aqe.finalPhysicalPlan
case p => p
}

test("DataFlintWindowPlannerStrategy replaces WindowExec with DataFlintWindowExec for SQL window") {
val session = spark
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Integration test helper: registers one DataFlint Python exec node scenario as a temp view.

Usage: pass the test name as the first argument (sys.argv[1]):
batch_eval — @udf → registers batch_eval_view
arrow_eval — @pandas_udf scalar → registers arrow_eval_view
flat_map_groups — applyInPandas → registers flat_map_groups_view
flat_map_cogroups— cogroup.applyIn… → registers flat_map_cogroups_view

The Scala test (DataFlintPythonIntegrationSpec) calls PythonRunner.main with one of
the above names, then checks the corresponding view's executedPlan.
"""
import sys
import pyspark
import pyspark.java_gateway
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, pandas_udf
from pyspark.sql.types import LongType
import pandas as pd

test_name = sys.argv[1] if len(sys.argv) > 1 else ""

gateway = pyspark.java_gateway.launch_gateway()
static = gateway.jvm.org.apache.spark.dataflint.DataFlintStaticSession
jsc = static.javaSparkContext()
spark_jvm = static.session()

conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf())
sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf)
spark = SparkSession(sc, jsparkSession=spark_jvm)

df = spark.createDataFrame(
[(1, "a"), (2, "b"), (3, "a"), (4, "b")],
["id", "cat"]
)


def register_batch_eval():
    # Plain (row-at-a-time) Python UDF — planned as a BatchEvalPython node.
    def twice(value):
        return value * 2

    doubled = df.select(udf(twice, LongType())("id"))
    doubled.createOrReplaceTempView("batch_eval_view")


def register_arrow_eval():
    # Scalar pandas UDF (Series -> Series type hints) — planned as an ArrowEvalPython node.
    @pandas_udf(LongType())
    def twice_vectorised(values: pd.Series) -> pd.Series:
        return values * 2

    doubled = df.select(twice_vectorised("id"))
    doubled.createOrReplaceTempView("arrow_eval_view")


def register_flat_map_groups():
    # Grouped-map pandas UDF — planned as a FlatMapGroupsInPandas node.
    # The (key, pdf) two-argument form simply passes each group through unchanged.
    grouped = df.groupby("cat").applyInPandas(
        lambda key, pdf: pdf, schema="id long, cat string"
    )
    grouped.createOrReplaceTempView("flat_map_groups_view")


def register_flat_map_cogroups():
    # Cogrouped-map pandas UDF — planned as a FlatMapCoGroupsInPandas node.
    labels = spark.createDataFrame(
        [(1, "x"), (2, "y"), (3, "z"), (4, "w")],
        ["id", "label"]
    )

    def attach_label(lhs: pd.DataFrame, rhs: pd.DataFrame) -> pd.DataFrame:
        # Copy the left group and stamp on the first matching label (None if absent).
        out = lhs.copy()
        out["label"] = rhs["label"].values[0] if len(rhs) > 0 else None
        return out[["id", "cat", "label"]]

    joined = df.groupby("id").cogroup(labels.groupby("id")).applyInPandas(
        attach_label, schema="id long, cat string, label string"
    )
    joined.createOrReplaceTempView("flat_map_cogroups_view")


# Map each scenario name to the function that registers its temp view.
_tests = {
    "batch_eval": register_batch_eval,
    "arrow_eval": register_arrow_eval,
    "flat_map_groups": register_flat_map_groups,
    "flat_map_cogroups": register_flat_map_cogroups,
}

handler = _tests.get(test_name)
if handler is None:
    print(f"Unknown test '{test_name}'. Valid options: {list(_tests)}", file=sys.stderr)
    sys.exit(1)

handler()
Loading