diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index e45c1d7dd0..5d962e9df3 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2610,6 +2610,13 @@ def sync_node_execution( execution._interface = launched_exec._flyte_workflow.interface return execution + # Handle the case where it's a branch node + if execution._node.branch_node is not None: + logger.info( + "Skipping branch node execution for now - branch nodes will " "not have inputs and outputs filled in" + ) + return execution + # If a node ran a static subworkflow or a dynamic subworkflow then the parent flag will be set. if execution.metadata.is_parent_node: # We'll need to query child node executions regardless since this is a parent node @@ -2649,14 +2656,6 @@ def sync_node_execution( for cne in child_node_executions ] execution._interface = sub_flyte_workflow.interface - - # Handle the case where it's a branch node - elif execution._node.branch_node is not None: - logger.info( - "Skipping branch node execution for now - branch nodes will " - "not have inputs and outputs filled in" - ) - return execution else: logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}") raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}") diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 18481d9e69..3e7457f3b3 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -1358,3 +1358,12 @@ def test_run_wf_with_resource_requests_override(register): ], limits=[], ) + + +def test_conditional_workflow(): + execution_id = run("conditional_workflow.py", "wf") + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + print("Execution Error:", execution.error) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" diff --git a/tests/flytekit/integration/remote/workflows/basic/conditional_workflow.py b/tests/flytekit/integration/remote/workflows/basic/conditional_workflow.py new file mode 100644 index 0000000000..d052a5fc65 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/conditional_workflow.py @@ -0,0 +1,39 @@ +import flytekit as fl +from flytekit import conditional +from flytekit.core.task import Echo + +echo_radius = Echo(name="noop", inputs={"radius": float}) + + +@fl.task +def calculate_circle_circumference(radius: float) -> float: + return 2 * 3.14 * radius # Task to calculate the circumference of a circle + + +@fl.task +def calculate_circle_area(radius: float) -> float: + return 3.14 * radius * radius # Task to calculate the area of a circle + + +@fl.task +def nop(radius: float) -> float: + return radius # Task that does nothing, effectively a no-op + + +@fl.workflow +def wf(radius: float = 0.5, get_area: bool = False, get_circumference: bool = True): + echoed_radius = nop(radius=radius) + ( + conditional("if_area") + .if_(get_area.is_true()) + .then(calculate_circle_area(radius=radius)) + .else_() + .then(echo_radius(echoed_radius)) + ) + ( + conditional("if_circumference") + .if_(get_circumference.is_true()) + .then(calculate_circle_circumference(radius=echoed_radius)) + .else_() + .then(echo_radius(echoed_radius)) + )