diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index f8a6929ea0e3b..3728994df3307 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -111,10 +111,13 @@ def execute(self, context: Context) -> list[str]: # steps may arrive as a string representing a list # e.g. if we used XCom or a file then: steps="[{ step1 }, { step2 }]" steps = self.steps + wait_for_completion = self.wait_for_completion if isinstance(steps, str): steps = ast.literal_eval(steps) - return emr_hook.add_job_flow_steps(job_flow_id=job_flow_id, steps=steps, wait_for_completion=True) + return emr_hook.add_job_flow_steps( + job_flow_id=job_flow_id, steps=steps, wait_for_completion=wait_for_completion + ) class EmrStartNotebookExecutionOperator(BaseOperator): diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 5b5f51030b4f1..c088c3bc2e951 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -207,3 +207,21 @@ def test_init_with_nonexistent_cluster_name(self): with pytest.raises(AirflowException) as ctx: operator.execute(self.mock_context) assert str(ctx.value) == f"No cluster found for name: {cluster_name}" + + @patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.add_job_flow_steps") + def test_wait_for_completion(self, mock_add_job_flow_steps): + job_flow_id = "j-8989898989" + operator = EmrAddStepsOperator( + task_id="test_task", + job_flow_id=job_flow_id, + aws_conn_id="aws_default", + dag=DAG("test_dag_id", default_args=self.args), + wait_for_completion=False, + ) + operator.execute(self.mock_context) + + mock_add_job_flow_steps.assert_called_once_with( + job_flow_id=job_flow_id, + steps=[], + wait_for_completion=False, + )