From fda7d9f73c889562b8510b01cf8dc720e2ef27c6 Mon Sep 17 00:00:00 2001 From: Roman Sheludko Date: Mon, 19 Dec 2022 15:19:46 +0100 Subject: [PATCH 1/3] Add waiter config params to emr.add_job_flow_steps --- airflow/providers/amazon/aws/hooks/emr.py | 13 ++++++++++--- airflow/providers/amazon/aws/operators/emr.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index fcc9945711741..af5866ded65b1 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -126,7 +126,12 @@ def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]: return response def add_job_flow_steps( - self, job_flow_id: str, steps: list[dict] | str | None = None, wait_for_completion: bool = False + self, + job_flow_id: str, + steps: list[dict] | str | None = None, + wait_for_completion: bool = False, + waiter_delay: int = 5, + waiter_max_attempts: int = 100, ) -> list[str]: """ Add new steps to a running cluster. @@ -134,6 +139,8 @@ def add_job_flow_steps( :param job_flow_id: The id of the job flow to which the steps are being added :param steps: A list of the steps to be executed by the job flow :param wait_for_completion: If True, wait for the steps to be completed. Default is False + :param waiter_delay: The amount of time in seconds to wait between attempts. Default is 5 + :param waiter_max_attempts: The maximum number of attempts to be made. Default is 100 """ response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps) @@ -148,8 +155,8 @@ def add_job_flow_steps( ClusterId=job_flow_id, StepId=step_id, WaiterConfig={ - "Delay": 5, - "MaxAttempts": 100, + "Delay": waiter_delay, + "MaxAttempts": waiter_max_attempts, }, ) return response["StepIds"] diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 3728994df3307..b606fcd6a2ee9 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -71,6 +71,8 @@ def __init__( aws_conn_id: str = "aws_default", steps: list[dict] | str | None = None, wait_for_completion: bool = False, + waiter_delay: int = 5, + waiter_max_attempts: int = 100, **kwargs, ): if not exactly_one(job_flow_id is None, job_flow_name is None): @@ -84,6 +86,8 @@ def __init__( self.cluster_states = cluster_states self.steps = steps self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts def execute(self, context: Context) -> list[str]: emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) @@ -112,11 +116,17 @@ def execute(self, context: Context) -> list[str]: # e.g. if we used XCom or a file then: steps="[{ step1 }, { step2 }]" steps = self.steps wait_for_completion = self.wait_for_completion + waiter_delay = self.waiter_delay + waiter_max_attempts = self.waiter_max_attempts if isinstance(steps, str): steps = ast.literal_eval(steps) return emr_hook.add_job_flow_steps( - job_flow_id=job_flow_id, steps=steps, wait_for_completion=wait_for_completion + job_flow_id=job_flow_id, + steps=steps, + wait_for_completion=wait_for_completion, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, ) From 2dae32426dad97828c4cc72a82a1dbbe53b64197 Mon Sep 17 00:00:00 2001 From: Roman Sheludko Date: Mon, 19 Dec 2022 16:48:25 +0100 Subject: [PATCH 2/3] Fix checks --- airflow/providers/amazon/aws/hooks/emr.py | 2 +- airflow/providers/amazon/aws/operators/emr.py | 2 +- tests/providers/amazon/aws/operators/test_emr_add_steps.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index af5866ded65b1..b395726eea22e 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -131,7 +131,7 @@ def add_job_flow_steps( steps: list[dict] | str | None = None, wait_for_completion: bool = False, waiter_delay: int = 5, - waiter_max_attempts: int = 100, + waiter_max_attempts: int = 100, ) -> list[str]: """ Add new steps to a running cluster. diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index b606fcd6a2ee9..a0397e4caeb03 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -116,7 +116,7 @@ def execute(self, context: Context) -> list[str]: # e.g. if we used XCom or a file then: steps="[{ step1 }, { step2 }]" steps = self.steps wait_for_completion = self.wait_for_completion - waiter_delay = self.waiter_delay + waiter_delay = self.waiter_delay waiter_max_attempts = self.waiter_max_attempts if isinstance(steps, str): steps = ast.literal_eval(steps) diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index c088c3bc2e951..82d38d5d39b83 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -224,4 +224,6 @@ def test_wait_for_completion(self, mock_add_job_flow_steps): job_flow_id=job_flow_id, steps=[], wait_for_completion=False, + waiter_delay=5, + waiter_max_attempts=100, ) From fbcc396e9b3b74dd7fa61d4389aa74200c9ed9ac Mon Sep 17 00:00:00 2001 From: Roman Sheludko Date: Tue, 3 Jan 2023 08:21:08 +0100 Subject: [PATCH 3/3] Use default none values --- airflow/providers/amazon/aws/hooks/emr.py | 15 +++++++++------ airflow/providers/amazon/aws/operators/emr.py | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index b395726eea22e..53ba8ab42bf97 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.utils.waiter import get_state, waiter +from airflow.utils.helpers import prune_dict class EmrHook(AwsBaseHook): @@ -130,8 +131,8 @@ def add_job_flow_steps( job_flow_id: str, steps: list[dict] | str | None = None, wait_for_completion: bool = False, - waiter_delay: int = 5, - waiter_max_attempts: int = 100, + waiter_delay: int | None = None, + waiter_max_attempts: int | None = None, ) -> list[str]: """ Add new steps to a running cluster. @@ -154,10 +155,12 @@ def add_job_flow_steps( waiter.wait( ClusterId=job_flow_id, StepId=step_id, - WaiterConfig={ - "Delay": waiter_delay, - "MaxAttempts": waiter_max_attempts, - }, + WaiterConfig=prune_dict( + { + "Delay": waiter_delay, + "MaxAttempts": waiter_max_attempts, + } + ), ) return response["StepIds"] diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index a0397e4caeb03..88c6f34d30c7a 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -71,8 +71,8 @@ def __init__( aws_conn_id: str = "aws_default", steps: list[dict] | str | None = None, wait_for_completion: bool = False, - waiter_delay: int = 5, - waiter_max_attempts: int = 100, + waiter_delay: int | None = None, + waiter_max_attempts: int | None = None, **kwargs, ): if not exactly_one(job_flow_id is None, job_flow_name is None):