diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_serverless.py b/airflow/providers/amazon/aws/example_dags/example_emr_serverless.py new file mode 100644 index 0000000000000..b8c0618014808 --- /dev/null +++ b/airflow/providers/amazon/aws/example_dags/example_emr_serverless.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from os import getenv + +from airflow import DAG +from airflow.models.baseoperator import chain +from airflow.providers.amazon.aws.operators.emr import ( + EmrServerlessCreateApplicationOperator, + EmrServerlessDeleteApplicationOperator, + EmrServerlessStartJobOperator, +) +from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor + +EXECUTION_ROLE_ARN = getenv('EXECUTION_ROLE_ARN', 'execution_role_arn') +EMR_EXAMPLE_BUCKET = getenv('EMR_EXAMPLE_BUCKET', 'emr_example_bucket') +SPARK_JOB_DRIVER = { + "sparkSubmit": { + "entryPoint": "s3://us-east-1.elasticmapreduce/emr-containers/samples/wordcount/scripts/wordcount.py", + "entryPointArguments": [f"s3://{EMR_EXAMPLE_BUCKET}/output"], + "sparkSubmitParameters": "--conf spark.executor.cores=1 --conf spark.executor.memory=4g\ + --conf spark.driver.cores=1 --conf spark.driver.memory=4g --conf spark.executor.instances=1", + } +} + +SPARK_CONFIGURATION_OVERRIDES = { + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": f"s3://{EMR_EXAMPLE_BUCKET}/logs"}} +} + +with DAG( + dag_id='example_emr_serverless', + schedule_interval=None, + start_date=datetime(2021, 1, 1), + tags=['example'], + catchup=False, +) as emr_serverless_dag: + + # [START howto_operator_emr_serverless_create_application] + emr_serverless_app = EmrServerlessCreateApplicationOperator( + task_id='create_emr_serverless_task', + release_label='emr-6.6.0', + job_type="SPARK", + config={'name': 'new_application'}, + ) + # [END howto_operator_emr_serverless_create_application] + + # [START howto_sensor_emr_serverless_application] + wait_for_app_creation = EmrServerlessApplicationSensor( + task_id='wait_for_app_creation', + application_id=emr_serverless_app.output, + ) + # [END howto_sensor_emr_serverless_application] + + # [START howto_operator_emr_serverless_start_job] + start_job = EmrServerlessStartJobOperator( + task_id='start_emr_serverless_job', + application_id=emr_serverless_app.output, + execution_role_arn=EXECUTION_ROLE_ARN, + job_driver=SPARK_JOB_DRIVER, + configuration_overrides=SPARK_CONFIGURATION_OVERRIDES, + ) + # [END howto_operator_emr_serverless_start_job] + + # [START howto_sensor_emr_serverless_job] + wait_for_job = EmrServerlessJobSensor( + task_id='wait_for_job', application_id=emr_serverless_app.output, job_run_id=start_job.output + ) + # [END howto_sensor_emr_serverless_job] + + # [START howto_operator_emr_serverless_delete_application] + delete_app = EmrServerlessDeleteApplicationOperator( + task_id='delete_application', application_id=emr_serverless_app.output, trigger_rule="all_done" + ) + # [END howto_operator_emr_serverless_delete_application] + + chain( + emr_serverless_app, + wait_for_app_creation, + start_job, + wait_for_job, + delete_app, + ) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index 2141b38ed8fc4..e085bff8999e1 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -16,10 +16,11 @@ # specific language governing permissions and limitations # under the License. from time import sleep -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Set from botocore.exceptions import ClientError +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -90,6 +91,78 @@ def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]: return response +class EmrServerlessHook(AwsBaseHook): + """ + Interact with EMR Serverless API. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["client_type"] = "emr-serverless" + super().__init__(*args, **kwargs) + + @cached_property + def conn(self): + """Get the underlying boto3 EmrServerlessAPIService client (cached)""" + return super().conn + + # This method should be replaced with boto waiters which would implement timeouts and backoff nicely. + def waiter( + self, + get_state_callable: Callable, + get_state_args: Dict, + parse_response: List, + desired_state: Set, + failure_states: Set, + object_type: str, + action: str, + countdown: int = 25 * 60, + check_interval_seconds: int = 60, + ) -> None: + """ + Will run the sensor until it turns True. + + :param get_state_callable: A callable to run until it returns True + :param get_state_args: Arguments to pass to get_state_callable + :param parse_response: Dictionary keys to extract state from response of get_state_callable + :param desired_state: Wait until the getter returns this value + :param failure_states: A set of states which indicate failure and should throw an + exception if any are reached before the desired_state + :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc) + :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc) + :param countdown: Total amount of time the waiter should wait for the desired state + before timing out (in seconds). Defaults to 25 * 60 seconds. + :param check_interval_seconds: Number of seconds waiter should wait before attempting + to retry get_state_callable. Defaults to 60 seconds. + """ + response = get_state_callable(**get_state_args) + state: str = self.get_state(response, parse_response) + while state not in desired_state: + if state in failure_states: + raise AirflowException(f'{object_type.title()} reached failure state {state}.') + if countdown >= check_interval_seconds: + countdown -= check_interval_seconds + self.log.info('Waiting for %s to be %s.', object_type.lower(), action.lower()) + sleep(check_interval_seconds) + state = self.get_state(get_state_callable(**get_state_args), parse_response) + else: + message = f'{object_type.title()} still not {action.lower()} after the allocated time limit.' + self.log.error(message) + raise RuntimeError(message) + + def get_state(self, response, keys) -> str: + value = response + for key in keys: + if value is not None: + value = value.get(key, None) + return value + + class EmrContainerHook(AwsBaseHook): """ Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 4e8a1d96c9895..1cac1eb0f5d48 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -19,15 +19,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from uuid import uuid4 -from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook +from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink +from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor if TYPE_CHECKING: from airflow.utils.context import Context +from airflow.compat.functools import cached_property + class EmrAddStepsOperator(BaseOperator): """ @@ -412,3 +414,259 @@ def execute(self, context: 'Context') -> None: raise AirflowException(f'JobFlow termination failed: {response}') else: self.log.info('JobFlow with id %s terminated', self.job_flow_id) + + +class EmrServerlessCreateApplicationOperator(BaseOperator): + """ + Operator to create Serverless EMR Application + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessCreateApplicationOperator` + + :param release_label: The EMR release version associated with the application. + :param job_type: The type of application you want to start, such as Spark or Hive. + :param wait_for_completion: If true, wait for the Application to start before returning. Default to True + :param client_request_token: The client idempotency token of the application to create. + Its value must be unique for each request. + :param config: Optional dictionary for arbitrary parameters to the boto API create_application call. + :param aws_conn_id: AWS connection to use + """ + + def __init__( + self, + release_label: str, + job_type: str, + client_request_token: str = '', + config: Optional[dict] = None, + wait_for_completion: bool = True, + aws_conn_id: str = 'aws_default', + **kwargs, + ): + self.aws_conn_id = aws_conn_id + self.release_label = release_label + self.job_type = job_type + self.wait_for_completion = wait_for_completion + self.kwargs = kwargs + self.config = config or {} + super().__init__(**kwargs) + + self.client_request_token = client_request_token or str(uuid4()) + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook.""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + def execute(self, context: 'Context'): + response = self.hook.conn.create_application( + clientToken=self.client_request_token, + releaseLabel=self.release_label, + type=self.job_type, + **self.config, + ) + application_id = response['applicationId'] + + if response['ResponseMetadata']['HTTPStatusCode'] != 200: + raise AirflowException(f'Application Creation failed: {response}') + + self.log.info('EMR serverless application created: %s', application_id) + + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={'applicationId': application_id}, + parse_response=['application', 'state'], + desired_state={'CREATED'}, + failure_states=EmrServerlessApplicationSensor.FAILURE_STATES, + object_type='application', + action='created', + ) + + self.log.info('Starting application %s', application_id) + self.hook.conn.start_application(applicationId=application_id) + + if self.wait_for_completion: + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={'applicationId': application_id}, + parse_response=['application', 'state'], + desired_state={'STARTED'}, + failure_states=EmrServerlessApplicationSensor.FAILURE_STATES, + object_type='application', + action='started', + ) + + return application_id + + +class EmrServerlessStartJobOperator(BaseOperator): + """ + Operator to start EMR Serverless job. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessStartJobOperator` + + :param application_id: ID of the EMR Serverless application to start. + :param execution_role_arn: ARN of role to perform action. + :param job_driver: Driver that the job runs on. + :param configuration_overrides: Configuration specifications to override existing configurations. + :param client_request_token: The client idempotency token of the application to create. + Its value must be unique for each request. + :param config: Optional dictionary for arbitrary parameters to the boto API start_job_run call. + :param wait_for_completion: If true, waits for the job to start before returning. Defaults to True. + :param aws_conn_id: AWS connection to use + """ + + template_fields: Sequence[str] = ( + 'application_id', + 'execution_role_arn', + 'job_driver', + 'configuration_overrides', + ) + + def __init__( + self, + application_id: str, + execution_role_arn: str, + job_driver: dict, + configuration_overrides: Optional[dict], + client_request_token: str = '', + config: Optional[dict] = None, + wait_for_completion: bool = True, + aws_conn_id: str = 'aws_default', + **kwargs, + ): + self.aws_conn_id = aws_conn_id + self.application_id = application_id + self.execution_role_arn = execution_role_arn + self.job_driver = job_driver + self.configuration_overrides = configuration_overrides + self.wait_for_completion = wait_for_completion + self.config = config or {} + super().__init__(**kwargs) + + self.client_request_token = client_request_token or str(uuid4()) + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook.""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + def execute(self, context: 'Context') -> Dict: + self.log.info('Starting job on Application: %s', self.application_id) + + app_state = self.hook.conn.get_application(applicationId=self.application_id)['application']['state'] + if app_state not in EmrServerlessApplicationSensor.SUCCESS_STATES: + self.hook.conn.start_application(applicationId=self.application_id) + + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={'applicationId': self.application_id}, + parse_response=['application', 'state'], + desired_state={'STARTED'}, + failure_states=EmrServerlessApplicationSensor.FAILURE_STATES, + object_type='application', + action='started', + ) + + response = self.hook.conn.start_job_run( + clientToken=self.client_request_token, + applicationId=self.application_id, + executionRoleArn=self.execution_role_arn, + jobDriver=self.job_driver, + configurationOverrides=self.configuration_overrides, + **self.config, + ) + + if response['ResponseMetadata']['HTTPStatusCode'] != 200: + raise AirflowException(f'EMR serverless job failed to start: {response}') + + self.log.info('EMR serverless job started: %s', response['jobRunId']) + if self.wait_for_completion: + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_job_run, + get_state_args={ + 'applicationId': self.application_id, + 'jobRunId': response['jobRunId'], + }, + parse_response=['jobRun', 'state'], + desired_state=EmrServerlessJobSensor.TERMINAL_STATES, + failure_states=EmrServerlessJobSensor.FAILURE_STATES, + object_type='job', + action='run', + ) + return response['jobRunId'] + + +class EmrServerlessDeleteApplicationOperator(BaseOperator): + """ + Operator to delete EMR Serverless application + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessDeleteApplicationOperator` + + :param application_id: ID of the EMR Serverless application to delete. + :param wait_for_completion: If true, wait for the Application to start before returning. Default to True + :param aws_conn_id: AWS connection to use + """ + + template_fields: Sequence[str] = ('application_id',) + + def __init__( + self, + application_id: str, + wait_for_completion: bool = True, + aws_conn_id: str = 'aws_default', + **kwargs, + ): + self.aws_conn_id = aws_conn_id + self.application_id = application_id + self.wait_for_completion = wait_for_completion + super().__init__(**kwargs) + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook.""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + def execute(self, context: 'Context') -> None: + self.log.info('Stopping application: %s', self.application_id) + self.hook.conn.stop_application(applicationId=self.application_id) + + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={ + 'applicationId': self.application_id, + }, + parse_response=['application', 'state'], + desired_state=EmrServerlessApplicationSensor.FAILURE_STATES, + failure_states=set(), + object_type='application', + action='stopped', + ) + + self.log.info('Deleting application: %s', self.application_id) + response = self.hook.conn.delete_application(applicationId=self.application_id) + + if response['ResponseMetadata']['HTTPStatusCode'] != 200: + raise AirflowException(f'Application deletion failed: {response}') + + if self.wait_for_completion: + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={'applicationId': self.application_id}, + parse_response=['application', 'state'], + desired_state={'TERMINATED'}, + failure_states=EmrServerlessApplicationSensor.FAILURE_STATES, + object_type='application', + action='deleted', + ) + + self.log.info('EMR serverless application deleted') diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index 62c74ea560934..7c09e2fb899ff 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -15,15 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Iterable, Optional, Sequence, Set, Union + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook +from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context import Context from airflow.compat.functools import cached_property -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook -from airflow.sensors.base import BaseSensorOperator class EmrBaseSensor(BaseSensorOperator): @@ -37,7 +38,7 @@ class EmrBaseSensor(BaseSensorOperator): Subclasses should set ``target_states`` and ``failed_states`` fields. - :param aws_conn_id: aws connection to uses + :param aws_conn_id: aws connection to use """ ui_color = '#66c3ff' @@ -111,6 +112,137 @@ def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: raise NotImplementedError('Please implement failure_message_from_response() in subclass') +class EmrServerlessJobSensor(BaseSensorOperator): + """ + Asks for the state of the job run until it reaches a failure state or success state. + If the job run fails, the task will fail. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EmrServerlessJobSensor` + + :param application_id: application_id to check the state of + :param job_run_id: job_run_id to check the state of + :param target_states: a set of states to wait for, defaults to 'SUCCESS' + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + """ + + INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'} + FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'} + SUCCESS_STATES = {'SUCCESS'} + TERMINAL_STATES = SUCCESS_STATES.union(FAILURE_STATES) + + template_fields: Sequence[str] = ( + 'application_id', + 'job_run_id', + ) + + def __init__( + self, + *, + application_id: str, + job_run_id: str, + target_states: Union[Set, FrozenSet] = frozenset(SUCCESS_STATES), + aws_conn_id: str = 'aws_default', + **kwargs: Any, + ) -> None: + self.aws_conn_id = aws_conn_id + self.target_states = target_states + self.application_id = application_id + self.job_run_id = job_run_id + super().__init__(**kwargs) + + def poke(self, context: 'Context') -> bool: + response = self.hook.conn.get_job_run(applicationId=self.application_id, jobRunId=self.job_run_id) + + state = response['jobRun']['state'] + + if state in self.FAILURE_STATES: + failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}" + raise AirflowException(failure_message) + + return state in self.target_states + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + @staticmethod + def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + """ + Get failure message from response dictionary. + + :param response: response from AWS API + :return: failure message + :rtype: Optional[str] + """ + return response['jobRun']['stateDetails'] + + +class EmrServerlessApplicationSensor(BaseSensorOperator): + """ + Asks for the state of the application until it reaches a failure state or success state. + If the application fails, the task will fail. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EmrServerlessApplicationSensor` + + :param application_id: application_id to check the state of + :param target_states: a set of states to wait for, defaults to {'CREATED', 'STARTED'} + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + """ + + template_fields: Sequence[str] = ('application_id',) + + INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'} + FAILURE_STATES = {'STOPPED', 'TERMINATED'} + SUCCESS_STATES = {'CREATED', 'STARTED'} + + def __init__( + self, + *, + application_id: str, + target_states: Union[Set, FrozenSet] = frozenset(SUCCESS_STATES), + aws_conn_id: str = 'aws_default', + **kwargs: Any, + ) -> None: + self.aws_conn_id = aws_conn_id + self.target_states = target_states + self.application_id = application_id + super().__init__(**kwargs) + + def poke(self, context: 'Context') -> bool: + state = None + + response = self.hook.conn.get_application(applicationId=self.application_id) + + state = response['application']['state'] + + if state in self.FAILURE_STATES: + failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}" + raise AirflowException(failure_message) + + return state in self.target_states + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + @staticmethod + def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + """ + Get failure message from response dictionary. + + :param response: response from AWS API + :return: failure message + :rtype: Optional[str] + """ + return response['application']['stateDetails'] + + class EmrContainerSensor(BaseSensorOperator): """ Asks for the state of the job run until it reaches a failure state or success state. diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 29921cc2d2d66..8e1973d6cd750 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -117,6 +117,12 @@ integrations: - /docs/apache-airflow-providers-amazon/operators/emr_eks.rst logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png tags: [aws] + - integration-name: Amazon EMR Serverless + external-doc-url: https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/emr_serverless.rst + logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png + tags: [aws] - integration-name: Amazon Glacier external-doc-url: https://aws.amazon.com/glacier/ logo: /integration-logos/aws/Amazon-S3-Glacier_light-bg@4x.png diff --git a/docs/apache-airflow-providers-amazon/operators/emr_serverless.rst b/docs/apache-airflow-providers-amazon/operators/emr_serverless.rst new file mode 100644 index 0000000000000..2496af2c402f4 --- /dev/null +++ b/docs/apache-airflow-providers-amazon/operators/emr_serverless.rst @@ -0,0 +1,113 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +=============================== +Amazon EMR Serverless Operators +=============================== + +`Amazon EMR Serverless `__ is a serverless option +in Amazon EMR that makes it easy for data analysts and engineers to run open-source big +data analytics frameworks without configuring, managing, and scaling clusters or servers. +You get all the features and benefits of Amazon EMR without the need for experts to plan +and manage clusters. + +Prerequisite Tasks +------------------ + +.. include:: _partials/prerequisite_tasks.rst + +Operators +--------- +.. _howto/operator:EmrServerlessCreateApplicationOperator: + +Create an EMR Serverless Application +==================================== + +You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessCreateApplicationOperator` to +create a new EMR Serverless Application. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_serverless_create_application] + :end-before: [END howto_operator_emr_serverless_create_application] + +.. _howto/operator:EmrServerlessStartJobOperator: + +Start an EMR Serverless Job +============================ + +You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessStartJobOperator` to +start an EMR Serverless Job. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_serverless_start_job] + :end-before: [END howto_operator_emr_serverless_start_job] + +.. _howto/operator:EmrServerlessDeleteApplicationOperator: + +Delete an EMR Serverless Application +==================================== + +You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessDeleteApplicationOperator` to +delete an EMR Serverless Application. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_serverless_delete_application] + :end-before: [END howto_operator_emr_serverless_delete_application] + +Sensors +------- + +.. _howto/sensor:EmrServerlessJobSensor: + +Wait on an EMR Serverless Job state +=================================== + +To monitor the state of an EMR Serverless Job you can use +:class:`~airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_emr_serverless_job] + :end-before: [END howto_sensor_emr_serverless_job] + +.. _howto/sensor:EmrServerlessApplicationSensor: + +Wait on an EMR Serverless Application state +============================================ + +To monitor the state of an EMR Serverless Application you can use +:class:`~airflow.providers.amazon.aws.sensors.emr.EmrServerlessApplicationSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_emr_serverless_application] + :end-before: [END howto_sensor_emr_serverless_application] + +Reference +--------- + +* `AWS boto3 library documentation for EMR Serverless `__ +* `Configure IAM Roles for EMR Serverless permissions `__ diff --git a/tests/providers/amazon/aws/hooks/test_emr_serverless.py b/tests/providers/amazon/aws/hooks/test_emr_serverless.py new file mode 100644 index 0000000000000..db17da63dc988 --- /dev/null +++ b/tests/providers/amazon/aws/hooks/test_emr_serverless.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook + +task_id = 'test_emr_serverless_create_application_operator' +application_id = 'test_application_id' +release_label = 'test' +job_type = 'test' +client_request_token = 'eac427d0-1c6d4df=-96aa-32423412' +config = {'name': 'test_application_emr_serverless'} + + +class TestEmrServerlessHook: + def test_conn_attribute(self): + hook = EmrServerlessHook(aws_conn_id='aws_default') + assert hasattr(hook, 'conn') + # Testing conn is a cached property + conn = hook.conn + conn2 = hook.conn + assert conn is conn2 + + def test_waiter_failure_then_success(self): + mock_call_function = mock.MagicMock() + mock_call_function.side_effect = [{'response': 'test_failure'}, {'response': 'test_success'}] + success_state = {'test_success'} + hook = EmrServerlessHook() + waiter_response = hook.waiter( + get_state_callable=mock_call_function, + get_state_args={}, + parse_response=['response'], + desired_state=success_state, + failure_states={}, + object_type='test_object', + action='testing', + check_interval_seconds=1, + ) + assert mock_call_function.call_count == 2 + assert waiter_response is None + + def test_waiter_success_state(self): + mock_call_function = mock.MagicMock() + mock_call_function.return_value = {'response': 'test_success'} + success_state = {'test_success'} + hook = EmrServerlessHook() + waiter_response = hook.waiter( + get_state_callable=mock_call_function, + get_state_args={}, + parse_response=['response'], + desired_state=success_state, + failure_states={}, + object_type='test_object', + action='testing', + ) + mock_call_function.assert_called_once() + assert waiter_response is None + + def test_waiter_failure_state(self): + mock_call_function = mock.MagicMock() + failure_state = {'test_failure'} + mock_call_function.return_value = {'response': 'test_failure'} + hook = EmrServerlessHook() + with pytest.raises(AirflowException) as ex_message: + hook.waiter( + get_state_callable=mock_call_function, + get_state_args={}, + parse_response=['response'], + desired_state={}, + failure_states=failure_state, + object_type='test_object', + action='testing', + ) + mock_call_function.assert_called_once() + assert str(ex_message.value) == f"Test_Object reached failure state {','.join(failure_state)}." + + def test_nested_waiter_success_state(self): + mock_call_function = mock.MagicMock() + mock_call_function.return_value = { + 'layer1': {'key1': 'value1', 'layer2': {'response': 'test_success'}} + } + success_state = {'test_success'} + hook = EmrServerlessHook() + waiter_response = hook.waiter( + get_state_callable=mock_call_function, + get_state_args={}, + parse_response=['layer1', 'layer2', 'response'], + desired_state=success_state, + failure_states={}, + object_type='test_object', + action='testing', + ) + mock_call_function.assert_called_once() + assert waiter_response is None + + def test_waiter_timeout(self): + mock_call_function = mock.MagicMock() + success_state = {'test_success'} + mock_call_function.return_value = {'response': 'pending'} + hook = EmrServerlessHook() + with pytest.raises(RuntimeError) as ex_message: + hook.waiter( + get_state_callable=mock_call_function, + get_state_args={}, + parse_response=['response'], + desired_state=success_state, + failure_states={}, + object_type='test_object', + action='testing', + check_interval_seconds=1, + countdown=3, + ) + assert mock_call_function.call_count == 4 + assert ( + str(ex_message.value) + == f'{"test_object".title()} still not {"testing".lower()} after the allocated time limit.' + ) diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py new file mode 100644 index 0000000000000..cdcea72949e55 --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -0,0 +1,409 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock +from uuid import UUID + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.operators.emr import ( + EmrServerlessCreateApplicationOperator, + EmrServerlessDeleteApplicationOperator, + EmrServerlessStartJobOperator, +) + +task_id = 'test_emr_serverless_task_id' +application_id = 'test_application_id' +release_label = 'test' +job_type = 'test' +client_request_token = 'eac427d0-1c6d-4dfb9a-32423412' +config = {'name': 'test_application_emr_serverless'} + +execution_role_arn = 'test_emr_serverless_role_arn' +job_driver = {'test_key': 'test_value'} +configuration_overrides = {'monitoringConfiguration': {'test_key': 'test_value'}} +job_run_id = 'test_job_run_id' + +application_id_delete_operator = 'test_emr_serverless_delete_application_operator' + + +class TestEmrServerlessCreateApplicationOperator: + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_execute_successfully_with_wait_for_completion(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.create_application.return_value = { + "applicationId": application_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + config=config, + ) + + id = operator.execute(None) + + mock_conn.create_application.assert_called_once_with( + clientToken=client_request_token, + releaseLabel=release_label, + type=job_type, + **config, + ) + mock_conn.start_application.assert_called_once_with(applicationId=application_id) + + assert mock_waiter.call_count == 2 + assert id == application_id + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_execute_successfully_no_wait_for_completion(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.create_application.return_value = { + "applicationId": application_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + wait_for_completion=False, + config=config, + ) + + id = operator.execute(None) + + mock_conn.create_application.assert_called_once_with( + clientToken=client_request_token, + releaseLabel=release_label, + type=job_type, + **config, + ) + mock_conn.start_application.assert_called_once_with(applicationId=application_id) + + mock_waiter.assert_called_once() + assert id == application_id + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_failed_create_application(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.create_application.return_value = { + "applicationId": application_id, + "ResponseMetadata": {"HTTPStatusCode": 404}, + } + + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + config=config, + ) + + with pytest.raises(AirflowException) as ex_message: + operator.execute(None) + + assert "Application Creation failed:" in str(ex_message.value) + + mock_conn.create_application.assert_called_once_with( + clientToken=client_request_token, + releaseLabel=release_label, + type=job_type, + **config, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_no_client_request_token(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.create_application.return_value = { + "applicationId": application_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + wait_for_completion=False, + config=config, + ) + + operator.execute(None) + generated_client_token = operator.client_request_token + + assert str(UUID(generated_client_token, version=4)) == generated_client_token + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_application_in_failure_state(self, mock_conn): + fail_state = "STOPPED" + mock_conn.get_application.return_value = {"application": {"state": fail_state}} + mock_conn.create_application.return_value = { + "applicationId": application_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + config=config, + ) + + with pytest.raises(AirflowException) as ex_message: + operator.execute(None) + + assert str(ex_message.value) == f"Application reached failure state {fail_state}." + + mock_conn.create_application.assert_called_once_with( + clientToken=client_request_token, + releaseLabel=release_label, + type=job_type, + **config, + ) + + +class TestEmrServerlessStartJobOperator: + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_job_run_app_started(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + 'jobRunId': job_run_id, + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + client_request_token=client_request_token, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + ) + + id = operator.execute(None) + + assert operator.wait_for_completion is True + mock_conn.get_application.assert_called_once_with(applicationId=application_id) + mock_waiter.assert_called_once() + assert id == job_run_id + mock_conn.start_job_run.assert_called_once_with( + clientToken=client_request_token, + applicationId=application_id, + executionRoleArn=execution_role_arn, + jobDriver=job_driver, + configurationOverrides=configuration_overrides, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_job_run_app_not_started(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} + mock_conn.start_job_run.return_value = { + 'jobRunId': job_run_id, + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + client_request_token=client_request_token, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + ) + + id = operator.execute(None) + + assert operator.wait_for_completion is True + mock_conn.get_application.assert_called_once_with(applicationId=application_id) + assert mock_waiter.call_count == 2 + assert id == job_run_id + mock_conn.start_job_run.assert_called_once_with( + clientToken=client_request_token, + applicationId=application_id, + executionRoleArn=execution_role_arn, + jobDriver=job_driver, + configurationOverrides=configuration_overrides, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} + mock_conn.start_job_run.return_value = { + 'jobRunId': job_run_id, + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + client_request_token=client_request_token, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + wait_for_completion=False, + ) + + id = operator.execute(None) + + mock_conn.get_application.assert_called_once_with(applicationId=application_id) + mock_waiter.assert_called_once() + assert id == job_run_id + mock_conn.start_job_run.assert_called_once_with( + clientToken=client_request_token, + applicationId=application_id, + executionRoleArn=execution_role_arn, + jobDriver=job_driver, + configurationOverrides=configuration_overrides, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + 'jobRunId': job_run_id, + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + client_request_token=client_request_token, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + wait_for_completion=False, + ) + + id = operator.execute(None) + assert id == job_run_id + mock_conn.start_job_run.assert_called_once_with( + clientToken=client_request_token, + applicationId=application_id, + executionRoleArn=execution_role_arn, + jobDriver=job_driver, + configurationOverrides=configuration_overrides, + ) + assert not mock_waiter.called + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_failed_start_job_run(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} + mock_conn.start_job_run.return_value = { + 'jobRunId': job_run_id, + 'ResponseMetadata': {'HTTPStatusCode': 404}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + client_request_token=client_request_token, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + ) + with pytest.raises(AirflowException) as ex_message: + operator.execute(None) + + assert "EMR serverless job failed to start:" in str(ex_message.value) + mock_conn.get_application.assert_called_once_with(applicationId=application_id) + mock_waiter.assert_called_once() + mock_conn.start_job_run.assert_called_once_with( + clientToken=client_request_token, + applicationId=application_id, + executionRoleArn=execution_role_arn, + jobDriver=job_driver, + configurationOverrides=configuration_overrides, + ) + + +class TestEmrServerlessDeleteOperator: + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_delete_application_with_wait_for_completion_successfully(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.stop_application.return_value = {} + mock_conn.delete_application.return_value = {'ResponseMetadata': {'HTTPStatusCode': 200}} + + operator = EmrServerlessDeleteApplicationOperator( + task_id=task_id, application_id=application_id_delete_operator + ) + + operator.execute(None) + + assert operator.wait_for_completion is True + assert mock_waiter.call_count == 2 + mock_conn.stop_application.assert_called_once() + mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator) + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_delete_application_without_wait_for_completion_successfully(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.stop_application.return_value = {} + mock_conn.delete_application.return_value = {'ResponseMetadata': {'HTTPStatusCode': 200}} + + operator = EmrServerlessDeleteApplicationOperator( + task_id=task_id, + application_id=application_id_delete_operator, + wait_for_completion=False, + ) + + operator.execute(None) + + assert operator.wait_for_completion is False + mock_waiter.assert_called_once() + mock_conn.stop_application.assert_called_once() + mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator) + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + def test_delete_application_failed_deleteion(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + mock_conn.stop_application.return_value = {} + mock_conn.delete_application.return_value = {'ResponseMetadata': {'HTTPStatusCode': 400}} + + operator = EmrServerlessDeleteApplicationOperator( + task_id=task_id, application_id=application_id_delete_operator + ) + with pytest.raises(AirflowException) as ex_message: + operator.execute(None) + + assert "Application deletion failed:" in str(ex_message.value) + + assert operator.wait_for_completion is True + mock_waiter.assert_called_once() + mock_conn.stop_application.assert_called_once() + mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)