diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index 5ad2d241f3ce8..2b553a4f8851a 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.utils.waiter import get_state, waiter +from airflow.utils.helpers import prune_dict class EmrHook(AwsBaseHook): @@ -130,6 +131,8 @@ def add_job_flow_steps( job_flow_id: str, steps: list[dict] | str | None = None, wait_for_completion: bool = False, + waiter_delay: int | None = None, + waiter_max_attempts: int | None = None, execution_role_arn: str | None = None, ) -> list[str]: """ @@ -138,6 +141,8 @@ def add_job_flow_steps( :param job_flow_id: The id of the job flow to which the steps are being added :param steps: A list of the steps to be executed by the job flow :param wait_for_completion: If True, wait for the steps to be completed. Default is False + :param waiter_delay: The amount of time in seconds to wait between attempts. Default is None, which leaves the waiter's built-in delay in effect + :param waiter_max_attempts: The maximum number of attempts to be made. Default is None, which leaves the waiter's built-in attempt limit in effect :param execution_role_arn: The ARN of the runtime role for a step on the cluster. 
""" config = {} @@ -155,10 +160,12 @@ def add_job_flow_steps( waiter.wait( ClusterId=job_flow_id, StepId=step_id, - WaiterConfig={ - "Delay": 5, - "MaxAttempts": 100, - }, + WaiterConfig=prune_dict( + { + "Delay": waiter_delay, + "MaxAttempts": waiter_max_attempts, + } + ), ) return response["StepIds"] diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 7bc81e3d4d688..ab3de9d2bdccb 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -78,6 +78,8 @@ def __init__( aws_conn_id: str = "aws_default", steps: list[dict] | str | None = None, wait_for_completion: bool = False, + waiter_delay: int | None = None, + waiter_max_attempts: int | None = None, execution_role_arn: str | None = None, **kwargs, ): @@ -92,6 +94,8 @@ def __init__( self.cluster_states = cluster_states self.steps = steps self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts self.execution_role_arn = execution_role_arn def execute(self, context: Context) -> list[str]: @@ -126,6 +130,8 @@ def execute(self, context: Context) -> list[str]: job_flow_id=job_flow_id, steps=steps, wait_for_completion=self.wait_for_completion, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, execution_role_arn=self.execution_role_arn, ) diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 7ac459f398200..6f9c1c1b45922 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -224,5 +224,7 @@ def test_wait_for_completion(self, mock_add_job_flow_steps): job_flow_id=job_flow_id, steps=[], wait_for_completion=False, + waiter_delay=None, + waiter_max_attempts=None, execution_role_arn=None, )