From de093760be249afb67a7f60ab7e0936540a5af45 Mon Sep 17 00:00:00 2001 From: Adrian Lazar Date: Tue, 19 Nov 2024 10:34:29 +0200 Subject: [PATCH] Add wait_policy option to EmrCreateJobFlowOperator. Possible values: - None: No wait (default) - WaitPolicy.WAIT_FOR_COMPLETION: Previous behaviour when wait_for_completion was True - WaitPolicy.WAIT_FOR_STEPS_COMPLETION: New behaviour - wait for the cluster to terminate. --- .../operators/emr/emr.rst | 9 +++++ .../providers/amazon/aws/operators/emr.py | 36 +++++++++++++++---- .../providers/amazon/aws/utils/waiter.py | 20 +++++++++++ .../aws/operators/test_emr_create_job_flow.py | 14 ++++++-- 4 files changed, 69 insertions(+), 10 deletions(-) diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst index 5e32baa151c4f..a92837eac3526 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst @@ -47,6 +47,15 @@ Create an EMR job flow You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowOperator` to create a new EMR job flow. The cluster will be terminated automatically after finishing the steps. + +The default behaviour is to mark the DAG Task node as success as soon as the cluster is launched +(``wait_policy=None``). +It is possible to modify this behaviour by using a different ``wait_policy``. Available options are: + +- ``WaitPolicy.WAIT_FOR_COMPLETION`` - DAG Task node waits for the cluster to be running +- ``WaitPolicy.WAIT_FOR_STEPS_COMPLETION`` - DAG Task node waits for the cluster to terminate + + This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. 
Using ``deferrable`` mode will release worker slots and leads to efficient utilization of resources within Airflow cluster.However this mode will need the Airflow triggerer to be diff --git a/providers/src/airflow/providers/amazon/aws/operators/emr.py b/providers/src/airflow/providers/amazon/aws/operators/emr.py index c6ba3bf29b857..d7680baf61a7f 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/emr.py +++ b/providers/src/airflow/providers/amazon/aws/operators/emr.py @@ -18,6 +18,7 @@ from __future__ import annotations import ast +import warnings from collections.abc import Sequence from datetime import timedelta from functools import cached_property @@ -25,7 +26,7 @@ from uuid import uuid4 from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import ( @@ -50,7 +51,11 @@ EmrTerminateJobFlowTrigger, ) from airflow.providers.amazon.aws.utils import validate_execute_complete_event -from airflow.providers.amazon.aws.utils.waiter import waiter +from airflow.providers.amazon.aws.utils.waiter import ( + WAITER_POLICY_NAME_MAPPING, + WaitPolicy, + waiter, +) from airflow.providers.amazon.aws.utils.waiter_with_logging import wait from airflow.utils.helpers import exactly_one, prune_dict from airflow.utils.types import NOTSET, ArgNotSet @@ -637,8 +642,14 @@ class EmrCreateJobFlowOperator(BaseOperator): :param job_flow_overrides: boto3 style arguments or reference to an arguments file (must be '.json') to override specific ``emr_conn_id`` extra parameters. 
(templated) :param region_name: Region named passed to EmrHook - :param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow + :param wait_for_completion: Deprecated - use `wait_policy` instead. + Whether to finish task immediately after creation (False) or wait for jobflow completion (True) + (default: None) + :param wait_policy: Whether to finish the task immediately after creation (None) or: + - wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION) + - wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION) + (default: None) :param waiter_max_attempts: Maximum number of tries before failing. :param waiter_delay: Number of seconds between polling the state of the notebook. :param deferrable: If True, the operator will wait asynchronously for the crawl to complete. @@ -666,7 +677,8 @@ def __init__( emr_conn_id: str | None = "emr_default", job_flow_overrides: str | dict[str, Any] | None = None, region_name: str | None = None, - wait_for_completion: bool = False, + wait_for_completion: bool | None = None, + wait_policy: WaitPolicy | None = None, waiter_max_attempts: int | None = None, waiter_delay: int | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), @@ -677,11 +689,20 @@ def __init__( self.emr_conn_id = emr_conn_id self.job_flow_overrides = job_flow_overrides or {} self.region_name = region_name - self.wait_for_completion = wait_for_completion + self.wait_policy = wait_policy self.waiter_max_attempts = waiter_max_attempts or 60 self.waiter_delay = waiter_delay or 60 self.deferrable = deferrable + if wait_for_completion is not None: + warnings.warn( + "`wait_for_completion` parameter is deprecated, please use `wait_policy` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + # preserve previous behaviour + self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION if wait_for_completion else None + 
@cached_property def _emr_hook(self) -> EmrHook: """Create and return an EmrHook.""" @@ -734,8 +755,9 @@ def execute(self, context: Context) -> str | None: # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), ) - if self.wait_for_completion: - self._emr_hook.get_waiter("job_flow_waiting").wait( + if self.wait_policy: + waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy] + self._emr_hook.get_waiter(waiter_name).wait( ClusterId=self._job_flow_id, WaiterConfig=prune_dict( { diff --git a/providers/src/airflow/providers/amazon/aws/utils/waiter.py b/providers/src/airflow/providers/amazon/aws/utils/waiter.py index b096c30203d03..72e03cec5d619 100644 --- a/providers/src/airflow/providers/amazon/aws/utils/waiter.py +++ b/providers/src/airflow/providers/amazon/aws/utils/waiter.py @@ -19,6 +19,7 @@ import logging import time +from enum import Enum from typing import Callable from airflow.exceptions import AirflowException @@ -83,3 +84,22 @@ def get_state(response, keys) -> str: if value is not None: value = value.get(key, None) return value + + +class WaitPolicy(str, Enum): + """ + Used to control the waiting behaviour within EmrCreateJobFlowOperator.
+ + Choices: + - WAIT_FOR_COMPLETION - Will wait for the cluster to report "Running" state + - WAIT_FOR_STEPS_COMPLETION - Will wait for the cluster to report "Terminated" state + """ + + WAIT_FOR_COMPLETION = "wait_for_completion" + WAIT_FOR_STEPS_COMPLETION = "wait_for_steps_completion" + + +WAITER_POLICY_NAME_MAPPING: dict[WaitPolicy, str] = { + WaitPolicy.WAIT_FOR_COMPLETION: "job_flow_waiting", + WaitPolicy.WAIT_FOR_STEPS_COMPLETION: "job_flow_terminated", +} diff --git a/providers/tests/amazon/aws/operators/test_emr_create_job_flow.py b/providers/tests/amazon/aws/operators/test_emr_create_job_flow.py index a17304fc64416..2bd6fa444f569 100644 --- a/providers/tests/amazon/aws/operators/test_emr_create_job_flow.py +++ b/providers/tests/amazon/aws/operators/test_emr_create_job_flow.py @@ -30,6 +30,7 @@ from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger +from airflow.providers.amazon.aws.utils.waiter import WAITER_POLICY_NAME_MAPPING, WaitPolicy from airflow.utils import timezone from airflow.utils.types import DagRunType @@ -193,17 +194,24 @@ def test_execute_returns_job_id(self, mocked_hook_client): mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN assert self.operator.execute(self.mock_context) == JOB_FLOW_ID + @pytest.mark.parametrize( + "wait_policy", + [ + pytest.param(WaitPolicy.WAIT_FOR_COMPLETION, id="with wait for completion"), + pytest.param(WaitPolicy.WAIT_FOR_STEPS_COMPLETION, id="with wait for steps completion policy"), + ], + ) @mock.patch("botocore.waiter.get_service_module_name", return_value="emr") @mock.patch.object(Waiter, "wait") - def test_execute_with_wait(self, mock_waiter, _, mocked_hook_client): + def test_execute_with_wait_policy(self, mock_waiter, _, mocked_hook_client, wait_policy: WaitPolicy): mocked_hook_client.run_job_flow.return_value = 
RUN_JOB_FLOW_SUCCESS_RETURN # Mock out the emr_client creator - self.operator.wait_for_completion = True + self.operator.wait_policy = wait_policy assert self.operator.execute(self.mock_context) == JOB_FLOW_ID mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY) - assert_expected_waiter_type(mock_waiter, "job_flow_waiting") + assert_expected_waiter_type(mock_waiter, WAITER_POLICY_NAME_MAPPING[wait_policy]) def test_create_job_flow_deferrable(self, mocked_hook_client): """