From 7ae3f18a9ffc6b4eeb275a24fe49e349238a0cd3 Mon Sep 17 00:00:00 2001 From: machichima Date: Wed, 4 Jun 2025 07:46:42 -0600 Subject: [PATCH 1/5] fix: use {} if sparkConf not set -e Signed-off-by: machichima --- plugins/flytekit-spark/flytekitplugins/spark/connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py index 44acb7c0f2..895c7d153d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py @@ -40,7 +40,7 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: if not new_cluster.get("docker_image"): new_cluster["docker_image"] = {"url": container.image} if not new_cluster.get("spark_conf"): - new_cluster["spark_conf"] = custom["sparkConf"] + new_cluster["spark_conf"] = custom.get("sparkConf", {}) if not new_cluster.get("spark_env_vars"): new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()} else: From 93ebc24a4662bb3ffb2bb07a9deb3f3be88e65fe Mon Sep 17 00:00:00 2001 From: machichima Date: Wed, 4 Jun 2025 07:46:59 -0600 Subject: [PATCH 2/5] test: for DatabricksV2 with and without spark_conf -e Signed-off-by: machichima --- .../flytekit-spark/tests/test_spark_task.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index ff3e1797b7..7198a4dec0 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -10,7 +10,7 @@ from flytekit import PodTemplate from flytekit.core import context_manager from flytekitplugins.spark import Spark -from flytekitplugins.spark.task import Databricks, new_spark_session +from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session from pyspark.sql import SparkSession import flytekit @@ 
-135,6 +135,46 @@ def my_databricks(a: int) -> int: assert my_databricks(a=3) == 3 +@pytest.mark.parametrize("spark_conf", [None, {"spark": "2"}]) +def test_databricks_v2(reset_spark_session, spark_conf): + databricks_conf = { + "name": "flytekit databricks plugin example", + "new_cluster": { + "spark_version": "11.0.x-scala2.12", + "node_type_id": "r3.xlarge", + "aws_attributes": {"availability": "ON_DEMAND"}, + "num_workers": 4, + "docker_image": {"url": "pingsutw/databricks:latest"}, + }, + "timeout_seconds": 3600, + "max_retries": 1, + "spark_python_task": { + "python_file": "dbfs:///FileStore/tables/entrypoint-1.py", + "parameters": "ls", + }, + } + + databricks_instance = "account.cloud.databricks.com" + + @task( + task_config=DatabricksV2( + databricks_conf=databricks_conf, + databricks_instance=databricks_instance, + spark_conf=spark_conf, + ) + ) + def my_databricks(a: int) -> int: + session = flytekit.current_context().spark_session + assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" + return a + + assert my_databricks.task_config is not None + assert my_databricks.task_config.databricks_conf == databricks_conf + assert my_databricks.task_config.databricks_instance == databricks_instance + assert my_databricks.task_config.spark_conf == (spark_conf or {}) + assert my_databricks(a=3) == 3 + + def test_new_spark_session(): name = "SessionName" spark_conf = {"spark1": "1", "spark2": "2"} From 42cdb6052f0c2e9aa8546c117e6aea62971dad7c Mon Sep 17 00:00:00 2001 From: machichima Date: Tue, 10 Jun 2025 22:11:03 +0800 Subject: [PATCH 3/5] fix: directly return for branch node for is(not)parent node -e Signed-off-by: machichima --- flytekit/remote/remote.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index e45c1d7dd0..486f2645bf 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2610,6 +2610,14 @@ def sync_node_execution( 
execution._interface = launched_exec._flyte_workflow.interface return execution + # Handle the case where it's a branch node + if execution._node.branch_node is not None: + logger.info( + "Skipping branch node execution for now - branch nodes will " + "not have inputs and outputs filled in" + ) + return execution + # If a node ran a static subworkflow or a dynamic subworkflow then the parent flag will be set. if execution.metadata.is_parent_node: # We'll need to query child node executions regardless since this is a parent node @@ -2649,14 +2657,6 @@ def sync_node_execution( for cne in child_node_executions ] execution._interface = sub_flyte_workflow.interface - - # Handle the case where it's a branch node - elif execution._node.branch_node is not None: - logger.info( - "Skipping branch node execution for now - branch nodes will " - "not have inputs and outputs filled in" - ) - return execution else: logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}") raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}") From 000b84def2e037d7cc4c89324b8f93fa307aa4f8 Mon Sep 17 00:00:00 2001 From: machichima Date: Thu, 12 Jun 2025 20:24:41 +0800 Subject: [PATCH 4/5] test: for running conditional workflow with remote.wait -e Signed-off-by: machichima --- .../integration/remote/test_remote.py | 9 +++++ .../workflows/basic/conditional_workflow.py | 39 +++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 tests/flytekit/integration/remote/workflows/basic/conditional_workflow.py diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 18481d9e69..3e7457f3b3 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -1358,3 +1358,12 @@ def test_run_wf_with_resource_requests_override(register): ], limits=[], ) + + +def test_conditional_workflow(): + execution_id = 
run("conditional_workflow.py", "wf") + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + print("Execution Error:", execution.error) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" diff --git a/tests/flytekit/integration/remote/workflows/basic/conditional_workflow.py b/tests/flytekit/integration/remote/workflows/basic/conditional_workflow.py new file mode 100644 index 0000000000..d052a5fc65 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/conditional_workflow.py @@ -0,0 +1,39 @@ +import flytekit as fl +from flytekit import conditional +from flytekit.core.task import Echo + +echo_radius = Echo(name="noop", inputs={"radius": float}) + + +@fl.task +def calculate_circle_circumference(radius: float) -> float: + return 2 * 3.14 * radius # Task to calculate the circumference of a circle + + +@fl.task +def calculate_circle_area(radius: float) -> float: + return 3.14 * radius * radius # Task to calculate the area of a circle + + +@fl.task +def nop(radius: float) -> float: + return radius # Task that does nothing, effectively a no-op + + +@fl.workflow +def wf(radius: float = 0.5, get_area: bool = False, get_circumference: bool = True): + echoed_radius = nop(radius=radius) + ( + conditional("if_area") + .if_(get_area.is_true()) + .then(calculate_circle_area(radius=radius)) + .else_() + .then(echo_radius(echoed_radius)) + ) + ( + conditional("if_circumference") + .if_(get_circumference.is_true()) + .then(calculate_circle_circumference(radius=echoed_radius)) + .else_() + .then(echo_radius(echoed_radius)) + ) From e22a920f9f2b9dd56f34e389696ab5acf2691793 Mon Sep 17 00:00:00 2001 From: machichima Date: Thu, 12 Jun 2025 20:37:30 +0800 Subject: [PATCH 5/5] refactor: lint -e Signed-off-by: machichima --- 
flytekit/remote/remote.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 486f2645bf..5d962e9df3 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2613,8 +2613,7 @@ def sync_node_execution( # Handle the case where it's a branch node if execution._node.branch_node is not None: logger.info( - "Skipping branch node execution for now - branch nodes will " - "not have inputs and outputs filled in" + "Skipping branch node execution for now - branch nodes will " "not have inputs and outputs filled in" ) return execution