From 6cd35442a0fc8a9f7ca03f6b9d4d10c490ff5bab Mon Sep 17 00:00:00 2001
From: Vincent Beck
Date: Fri, 11 Aug 2023 16:22:23 -0400
Subject: [PATCH 1/3] Add `deferrable` option to `LambdaCreateFunctionOperator`

---
 .../amazon/aws/operators/lambda_function.py   | 36 ++++++++++-
 .../providers/amazon/aws/triggers/athena.py   |  2 +-
 airflow/providers/amazon/aws/triggers/base.py |  1 -
 .../amazon/aws/triggers/lambda_function.py    | 64 +++++++++++++++++++
 .../operators/lambda.rst                      |  2 +
 .../aws/operators/test_lambda_function.py     | 15 +++++
 .../aws/triggers/test_lambda_function.py      | 54 ++++++++++++++++
 7 files changed, 171 insertions(+), 3 deletions(-)
 create mode 100644 airflow/providers/amazon/aws/triggers/lambda_function.py
 create mode 100644 tests/providers/amazon/aws/triggers/test_lambda_function.py

diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py
index 28b6313204221..52733be8793fe 100644
--- a/airflow/providers/amazon/aws/operators/lambda_function.py
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -18,11 +18,15 @@
 from __future__ import annotations
 
 import json
+from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow import AirflowException
+from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -50,6 +54,11 @@ class LambdaCreateFunctionOperator(BaseOperator):
     :param timeout: The amount of time (in seconds) that Lambda allows a function to run before stopping it.
     :param config: Optional dictionary for arbitrary parameters to the boto API create_lambda call.
     :param wait_for_completion: If True, the operator will wait until the function is active.
+    :param waiter_max_attempts: Maximum number of attempts to poll the creation.
+    :param waiter_delay: Number of seconds between polling the state of the creation.
+    :param deferrable: If True, the operator will wait asynchronously for the creation to complete.
+        This implies waiting for the creation to complete. This mode requires the aiobotocore module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
     :param aws_conn_id: The AWS connection ID to use
     """
 
@@ -75,6 +84,9 @@ def __init__(
         timeout: int | None = None,
         config: dict = {},
         wait_for_completion: bool = False,
+        waiter_max_attempts: int = 60,
+        waiter_delay: int = 15,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
@@ -88,6 +100,9 @@ def __init__(
         self.timeout = timeout
         self.config = config
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
         self.aws_conn_id = aws_conn_id
 
     @cached_property
@@ -108,6 +123,18 @@ def execute(self, context: Context):
         )
         self.log.info("Lambda response: %r", response)
 
+        if self.deferrable:
+            self.defer(
+                trigger=LambdaCreateFunctionCompleteTrigger(
+                    function_name=self.function_name,
+                    function_arn=response["FunctionArn"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            )
         if self.wait_for_completion:
             self.log.info("Wait for Lambda function to be active")
             waiter = self.hook.conn.get_waiter("function_active_v2")
@@ -117,6 +144,13 @@ def execute(self, context: Context):
 
         return response.get("FunctionArn")
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        if not event or event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+
+        self.log.info("Lambda function created successfully")
+        return event["function_arn"]
+
 
 class LambdaInvokeFunctionOperator(BaseOperator):
     """
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
index 636c1350598ea..fcd09dae19039 100644
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -23,7 +23,7 @@
 
 class AthenaTrigger(AwsBaseWaiterTrigger):
     """
-    Trigger for RedshiftCreateClusterOperator.
+    Trigger for AthenaOperator.
 
     The trigger will asynchronously poll the boto3 API and wait for the
     Redshift cluster to be in the `available` state.
diff --git a/airflow/providers/amazon/aws/triggers/base.py b/airflow/providers/amazon/aws/triggers/base.py
index 41f7d2dc33d79..d2d664d97f521 100644
--- a/airflow/providers/amazon/aws/triggers/base.py
+++ b/airflow/providers/amazon/aws/triggers/base.py
@@ -112,7 +112,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
     @abstractmethod
     def hook(self) -> AwsGenericHook:
         """Override in subclasses to return the right hook."""
-        ...
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
diff --git a/airflow/providers/amazon/aws/triggers/lambda_function.py b/airflow/providers/amazon/aws/triggers/lambda_function.py
new file mode 100644
index 0000000000000..f0f6a40551b96
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/lambda_function.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+
+class LambdaCreateFunctionCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to poll for the completion of a Lambda function creation.
+
+    :param function_name: The function name
+    :param function_arn: The function ARN
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        function_name: str,
+        function_arn: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 30,
+        aws_conn_id: str | None = None,
+    ) -> None:
+
+        super().__init__(
+            serialized_fields={"function_name": function_name, "function_arn": function_arn},
+            waiter_name="function_active_v2",
+            waiter_args={"FunctionName": function_name},
+            failure_message="Lambda function creation failed",
+            status_message="Status of Lambda function creation is",
+            status_queries=[
+                "Configuration.LastUpdateStatus",
+                "Configuration.LastUpdateStatusReason",
+                "Configuration.LastUpdateStatusReasonCode",
+            ],
+            return_key="function_arn",
+            return_value=function_arn,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return LambdaHook(aws_conn_id=self.aws_conn_id)
diff --git a/docs/apache-airflow-providers-amazon/operators/lambda.rst b/docs/apache-airflow-providers-amazon/operators/lambda.rst
index 79649f106fdee..0149b5b620c75 100644
--- a/docs/apache-airflow-providers-amazon/operators/lambda.rst
+++ b/docs/apache-airflow-providers-amazon/operators/lambda.rst
@@ -40,6 +40,8 @@ Create an AWS Lambda function
 To create an AWS lambda function you can use
 :class:`~airflow.providers.amazon.aws.operators.lambda_function.LambdaCreateFunctionOperator`.
 
+This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires
+the aiobotocore module to be installed.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_lambda.py
     :language: python
diff --git a/tests/providers/amazon/aws/operators/test_lambda_function.py b/tests/providers/amazon/aws/operators/test_lambda_function.py
index f0b4b834eb00d..6fc3a3b64f57c 100644
--- a/tests/providers/amazon/aws/operators/test_lambda_function.py
+++ b/tests/providers/amazon/aws/operators/test_lambda_function.py
@@ -22,6 +22,7 @@
 
 import pytest
 
+from airflow.exceptions import TaskDeferred
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 from airflow.providers.amazon.aws.operators.lambda_function import (
     LambdaCreateFunctionOperator,
@@ -69,6 +70,20 @@ def test_create_lambda_with_wait_for_completion(self, mock_hook_conn, mock_hook_
         mock_hook_create_lambda.assert_called_once()
         mock_hook_conn.get_waiter.assert_called_once_with("function_active_v2")
 
+    @mock.patch.object(LambdaHook, "create_lambda")
+    def test_create_lambda_deferrable(self, _):
+        operator = LambdaCreateFunctionOperator(
+            task_id="task_test",
+            function_name=FUNCTION_NAME,
+            role=ROLE_ARN,
+            code={
+                "ImageUri": IMAGE_URI,
+            },
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred):
+            operator.execute(None)
+
 
 class TestLambdaInvokeFunctionOperator:
     @pytest.mark.parametrize(
diff --git a/tests/providers/amazon/aws/triggers/test_lambda_function.py b/tests/providers/amazon/aws/triggers/test_lambda_function.py
new file mode 100644
index 0000000000000..c06a99d42e17a
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_lambda_function.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
+
+TEST_FUNCTION_NAME = "test-function-name"
+TEST_FUNCTION_ARN = "test-function-arn"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-conn-id"
+TEST_REGION_NAME = "test-region-name"
+
+
+class TestLambdaFunctionTriggers:
+    @pytest.mark.parametrize(
+        "trigger",
+        [
+            LambdaCreateFunctionCompleteTrigger(
+                function_name=TEST_FUNCTION_NAME,
+                function_arn=TEST_FUNCTION_ARN,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            )
+        ],
+    )
+    def test_serialize_recreate(self, trigger):
+        class_path, args = trigger.serialize()
+
+        class_name = class_path.split(".")[-1]
+        clazz = globals()[class_name]
+        instance = clazz(**args)
+
+        class_path2, args2 = instance.serialize()
+
+        assert class_path == class_path2
+        assert args == args2

From d332338fd14aab5d710dbb46d2dfe791c16f37a7 Mon Sep 17 00:00:00 2001
From: Vincent Beck
Date: Fri, 11 Aug 2023 16:42:20 -0400
Subject: [PATCH 2/3] Fix static checks

---
 airflow/providers/amazon/provider.yaml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 284cc429fa5dc..6aa32825105bf 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -544,6 +544,9 @@ triggers:
   - integration-name: Amazon EC2
     python-modules:
       - airflow.providers.amazon.aws.triggers.ec2
+  - integration-name: AWS Lambda
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.lambda_function
   - integration-name: Amazon Redshift
     python-modules:
       - airflow.providers.amazon.aws.triggers.redshift_cluster

From 55fcc783aba446a89af1d9c0ab4a61a9acaed44b Mon Sep 17 00:00:00 2001
From: Vincent Beck
Date: Mon, 14 Aug 2023 14:21:49 -0400
Subject: [PATCH 3/3] Fix type annotation

---
 airflow/providers/amazon/aws/operators/lambda_function.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py
index 66de2246bfa8a..5d7e980bb53ac 100644
--- a/airflow/providers/amazon/aws/operators/lambda_function.py
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -144,7 +144,7 @@ def execute(self, context: Context):
 
         return response.get("FunctionArn")
 
-    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
         if not event or event["status"] != "success":
             raise AirflowException(f"Trigger error: event is {event}")
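
For reference, a minimal DAG exercising the new deferrable mode might look like the sketch below. This is not part of the patch itself: the DAG id, function name, role ARN, and image URI are made-up placeholders, and a real container-image function would also need its package type and IAM permissions configured. Only parameters that appear in the operator signature above are used.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.lambda_function import LambdaCreateFunctionOperator

with DAG(
    dag_id="example_create_lambda_deferrable",  # placeholder DAG id
    start_date=datetime(2023, 8, 1),
    schedule=None,
    catchup=False,
):
    create_lambda = LambdaCreateFunctionOperator(
        task_id="create_lambda_function",
        function_name="my-function",  # placeholder
        role="arn:aws:iam::123456789012:role/my-lambda-role",  # placeholder role ARN
        code={"ImageUri": "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:latest"},  # placeholder
        deferrable=True,  # defer to the triggerer instead of blocking a worker; needs aiobotocore
        waiter_delay=15,  # seconds between polls of the function state
        waiter_max_attempts=60,  # defer timeout = waiter_delay * waiter_max_attempts seconds
    )

With deferrable=True, execute() submits the creation and immediately defers to LambdaCreateFunctionCompleteTrigger, so the worker slot is released while the triggerer polls the function_active_v2 waiter; the defer timeout equals waiter_delay * waiter_max_attempts seconds, matching the timedelta set in execute() in patch 1 above.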