# airflow/providers/amazon/aws/hooks/step_function.py
import json
from typing import Optional, Union

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


class StepFunctionHook(AwsBaseHook):
    """
    Interact with an AWS Step Functions State Machine.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`

    :param region_name: AWS region name, forwarded to the underlying AwsBaseHook
    :type region_name: Optional[str]
    """

    def __init__(self, region_name=None, *args, **kwargs):
        # region_name was previously accepted here and then discarded, so the
        # hook never used the region callers asked for. Forward it instead.
        # NOTE(review): assumes AwsBaseHook.__init__ accepts a ``region_name``
        # keyword -- confirm against base_aws.
        super().__init__(*args, client_type='stepfunctions', region_name=region_name, **kwargs)

    def start_execution(self, state_machine_arn: str, name: Optional[str] = None,
                        state_machine_input: Union[dict, str, None] = None) -> Optional[str]:
        """
        Start Execution of the State Machine.
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution

        :param state_machine_arn: AWS Step Function State Machine ARN
        :type state_machine_arn: str
        :param name: The name of the execution. Omitted from the request when None.
        :type name: Optional[str]
        :param state_machine_input: JSON data input to pass to the State Machine.
            A ``str`` is sent unchanged; any other non-None value is serialised
            with ``json.dumps``. Omitted from the request when None.
        :type state_machine_input: Union[Dict[str, any], str, None]
        :return: Execution ARN, or None when the response has no ``executionArn`` key
        :rtype: Optional[str]
        """
        execution_args = {
            'stateMachineArn': state_machine_arn
        }
        if name is not None:
            execution_args['name'] = name
        if state_machine_input is not None:
            if isinstance(state_machine_input, str):
                # Already serialised by the caller; pass through untouched.
                execution_args['input'] = state_machine_input
            else:
                # Previously anything that was not a str or dict was silently
                # dropped; serialise it (json.dumps raises TypeError for values
                # that cannot be encoded) rather than starting without input.
                execution_args['input'] = json.dumps(state_machine_input)

        self.log.info('Executing Step Function State Machine: %s', state_machine_arn)

        response = self.conn.start_execution(**execution_args)
        return response.get('executionArn', None)

    def describe_execution(self, execution_arn: str) -> dict:
        """
        Describes a State Machine Execution
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution

        :param execution_arn: ARN of the State Machine Execution
        :type execution_arn: str
        :return: Dict with Execution details
        :rtype: dict
        """
        return self.get_conn().describe_execution(executionArn=execution_arn)
# airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
import json

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.utils.decorators import apply_defaults


class StepFunctionGetExecutionOutputOperator(BaseOperator):
    """
    An Operator that returns the output of a Step Function State Machine Execution

    Additional arguments may be specified and are passed down to the underlying BaseOperator.

    .. seealso::
        :class:`~airflow.models.BaseOperator`

    :param execution_arn: ARN of the Step Function State Machine Execution
    :type execution_arn: str
    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
    :type aws_conn_id: str
    :param region_name: region name handed to the StepFunctionHook
    :type region_name: Optional[str]
    """
    template_fields = ['execution_arn']
    template_ext = ()
    ui_color = '#f9c915'

    @apply_defaults
    def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.region_name = region_name
        self.aws_conn_id = aws_conn_id
        self.execution_arn = execution_arn

    def execute(self, context):
        """
        Describe the execution and return its decoded output.

        :return: the JSON-decoded ``output`` of the execution description,
            or None when the description carries no ``output`` key
        """
        sfn_hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
        description = sfn_hook.describe_execution(self.execution_arn)

        if 'output' in description:
            decoded_output = json.loads(description['output'])
        else:
            decoded_output = None

        self.log.info('Got State Machine Execution output for %s', self.execution_arn)

        return decoded_output
# airflow/providers/amazon/aws/operators/step_function_start_execution.py
from typing import Optional, Union

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.utils.decorators import apply_defaults


class StepFunctionStartExecutionOperator(BaseOperator):
    """
    An Operator that begins execution of a Step Function State Machine

    Additional arguments may be specified and are passed down to the underlying BaseOperator.

    .. seealso::
        :class:`~airflow.models.BaseOperator`

    :param state_machine_arn: ARN of the Step Function State Machine
    :type state_machine_arn: str
    :param name: The name of the execution.
    :type name: Optional[str]
    :param state_machine_input: JSON data input to pass to the State Machine
    :type state_machine_input: Union[Dict[str, any], str, None]
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    :param region_name: region name handed to the StepFunctionHook
    :type region_name: Optional[str]
    """
    template_fields = ['state_machine_arn', 'name', 'input']
    template_ext = ()
    ui_color = '#f9c915'

    @apply_defaults
    def __init__(self, state_machine_arn: str, name: Optional[str] = None,
                 state_machine_input: Union[dict, str, None] = None,
                 aws_conn_id='aws_default', region_name=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.region_name = region_name
        self.aws_conn_id = aws_conn_id
        # Stored as ``input`` because that is the name listed in template_fields.
        self.input = state_machine_input
        self.name = name
        self.state_machine_arn = state_machine_arn

    def execute(self, context):
        """
        Start the State Machine execution and return its ARN.

        :return: ARN of the started execution
        :raises AirflowException: when the hook returns no execution ARN
        """
        sfn_hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
        started_arn = sfn_hook.start_execution(self.state_machine_arn, self.name, self.input)

        if started_arn is not None:
            self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, started_arn)
            return started_arn

        raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}')
# airflow/providers/amazon/aws/sensors/step_function_execution.py
import json

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults


class StepFunctionExecutionSensor(BaseSensorOperator):
    """
    Asks for the state of the Step Function State Machine Execution until it
    reaches a failure state or success state.
    If it fails, failing the task.

    On successful completion of the Execution the Sensor will do an XCom Push
    of the State Machine's output to `output`

    :param execution_arn: execution_arn to check the state of
    :type execution_arn: str
    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
    :type aws_conn_id: str
    :param region_name: region name handed to the StepFunctionHook
    :type region_name: Optional[str]
    """

    INTERMEDIATE_STATES = ('RUNNING',)
    FAILURE_STATES = ('FAILED', 'TIMED_OUT', 'ABORTED',)
    SUCCESS_STATES = ('SUCCEEDED',)

    template_fields = ['execution_arn']
    template_ext = ()
    ui_color = '#66c3ff'

    @apply_defaults
    def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.execution_arn = execution_arn
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        # Created lazily by get_hook() on the first poke.
        self.hook = None

    def poke(self, context):
        """
        Check the execution once.

        :return: True when the execution status is in SUCCESS_STATES (after
            pushing the decoded output to XCom under ``output``), False otherwise
        :raises AirflowException: when the execution status is in FAILURE_STATES
        """
        execution_status = self.get_hook().describe_execution(self.execution_arn)
        state = execution_status['status']
        output = json.loads(execution_status['output']) if 'output' in execution_status else None

        if state in self.FAILURE_STATES:
            raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}')

        if state not in self.SUCCESS_STATES:
            # Previously every status outside FAILURE_STATES/INTERMEDIATE_STATES
            # fell through and was reported as success, and SUCCESS_STATES was
            # never consulted. Keep waiting on anything that is not an explicit
            # success so an unlisted status cannot mark the task done.
            if state not in self.INTERMEDIATE_STATES:
                self.log.info('Execution %s is in unrecognised state %s; still waiting',
                              self.execution_arn, state)
            return False

        self.log.info('Doing xcom_push of output')
        self.xcom_push(context, 'output', output)
        return True

    def get_hook(self):
        """Create and return a StepFunctionHook"""
        if self.hook is None:
            self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
        return self.hook
# tests/providers/amazon/aws/hooks/test_step_function.py
import unittest

from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook

try:
    from moto import mock_stepfunctions
except ImportError:
    mock_stepfunctions = None


def _passthrough(func):
    """Return ``func`` unchanged; stands in for the moto decorator when moto is missing."""
    return func


# Applying ``@mock_stepfunctions`` directly fails with TypeError while the class
# body is executed when the import above fell back to None, which happens
# before the class-level skipIf can take effect. Decorate with a no-op in that
# case so the module still imports and the tests are skipped as intended.
_mock_sfn = mock_stepfunctions if mock_stepfunctions is not None else _passthrough


@unittest.skipIf(mock_stepfunctions is None, 'moto package not present')
class TestStepFunctionHook(unittest.TestCase):
    """Tests for StepFunctionHook; skipped when moto's Step Functions mock is unavailable."""

    @staticmethod
    def _start_pseudo_execution(hook):
        """Create a throwaway state machine through ``hook`` and start one execution of it.

        :return: the execution ARN returned by ``hook.start_execution``
        """
        state_machine = hook.get_conn().create_state_machine(
            name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role')

        state_machine_arn = state_machine.get('stateMachineArn', None)

        return hook.start_execution(
            state_machine_arn=state_machine_arn, name=None, state_machine_input={})

    @_mock_sfn
    def test_get_conn_returns_a_boto3_connection(self):
        hook = StepFunctionHook(aws_conn_id='aws_default')
        self.assertEqual('stepfunctions', hook.get_conn().meta.service_model.service_name)

    @_mock_sfn
    def test_start_execution(self):
        hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1')

        execution_arn = self._start_pseudo_execution(hook)

        self.assertIsNotNone(execution_arn)

    @_mock_sfn
    def test_describe_execution(self):
        hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1')

        execution_arn = self._start_pseudo_execution(hook)
        response = hook.describe_execution(execution_arn)

        self.assertIn('input', response)
# tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
import unittest
from unittest import mock
from unittest.mock import MagicMock

from airflow.providers.amazon.aws.operators.step_function_get_execution_output import (
    StepFunctionGetExecutionOutputOperator,
)

TASK_ID = 'step_function_get_execution_output'
EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\
    'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
AWS_CONN_ID = 'aws_non_default'
REGION_NAME = 'us-west-2'


class TestStepFunctionGetExecutionOutputOperator(unittest.TestCase):
    """Unit tests for StepFunctionGetExecutionOutputOperator with the hook mocked out."""

    def setUp(self):
        self.mock_context = MagicMock()

    @staticmethod
    def _build_operator():
        """Construct the operator under test from the module-level constants."""
        return StepFunctionGetExecutionOutputOperator(
            task_id=TASK_ID,
            execution_arn=EXECUTION_ARN,
            aws_conn_id=AWS_CONN_ID,
            region_name=REGION_NAME
        )

    def test_init(self):
        # Given / When
        op = self._build_operator()

        # Then
        self.assertEqual(TASK_ID, op.task_id)
        self.assertEqual(EXECUTION_ARN, op.execution_arn)
        self.assertEqual(AWS_CONN_ID, op.aws_conn_id)
        self.assertEqual(REGION_NAME, op.region_name)

    @mock.patch('airflow.providers.amazon.aws.operators.step_function_get_execution_output.StepFunctionHook')
    def test_execute(self, mock_hook):
        # Given: the hook describes an execution whose output is an empty JSON object
        mock_hook.return_value.describe_execution.return_value = {'output': '{}'}
        op = self._build_operator()

        # When
        outcome = op.execute(self.mock_context)

        # Then
        self.assertEqual({}, outcome)
# tests/providers/amazon/aws/operators/test_step_function_start_execution.py
import unittest
from unittest import mock
from unittest.mock import MagicMock

from airflow.providers.amazon.aws.operators.step_function_start_execution import (
    StepFunctionStartExecutionOperator,
)

TASK_ID = 'step_function_start_execution_task'
STATE_MACHINE_ARN = 'arn:aws:states:us-east-1:000000000000:stateMachine:pseudo-state-machine'
NAME = 'NAME'
INPUT = '{}'
AWS_CONN_ID = 'aws_non_default'
REGION_NAME = 'us-west-2'


class TestStepFunctionStartExecutionOperator(unittest.TestCase):
    """Unit tests for StepFunctionStartExecutionOperator with the hook mocked out."""

    def setUp(self):
        self.mock_context = MagicMock()

    @staticmethod
    def _build_operator():
        """Construct the operator under test from the module-level constants."""
        return StepFunctionStartExecutionOperator(
            task_id=TASK_ID,
            state_machine_arn=STATE_MACHINE_ARN,
            name=NAME,
            state_machine_input=INPUT,
            aws_conn_id=AWS_CONN_ID,
            region_name=REGION_NAME
        )

    def test_init(self):
        # Given / When
        op = self._build_operator()

        # Then
        self.assertEqual(TASK_ID, op.task_id)
        self.assertEqual(STATE_MACHINE_ARN, op.state_machine_arn)
        self.assertEqual(NAME, op.name)
        self.assertEqual(INPUT, op.input)
        self.assertEqual(AWS_CONN_ID, op.aws_conn_id)
        self.assertEqual(REGION_NAME, op.region_name)

    @mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook')
    def test_execute(self, mock_hook):
        # Given: the hook reports a started execution
        expected_arn = 'arn:aws:states:us-east-1:123456789012:execution:'\
            'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
        mock_hook.return_value.start_execution.return_value = expected_arn
        op = self._build_operator()

        # When
        outcome = op.execute(self.mock_context)

        # Then
        self.assertEqual(expected_arn, outcome)
# tests/providers/amazon/aws/sensors/test_step_function_execution.py
import unittest
from unittest import mock
from unittest.mock import MagicMock

from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.step_function_execution import StepFunctionExecutionSensor

TASK_ID = 'step_function_execution_sensor'
EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\
    'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
AWS_CONN_ID = 'aws_non_default'
REGION_NAME = 'us-west-2'


class TestStepFunctionExecutionSensor(unittest.TestCase):
    """Unit tests for StepFunctionExecutionSensor with the hook mocked out."""

    def setUp(self):
        self.mock_context = MagicMock()

    @staticmethod
    def _build_sensor():
        """Construct the sensor under test from the module-level constants."""
        return StepFunctionExecutionSensor(
            task_id=TASK_ID,
            execution_arn=EXECUTION_ARN,
            aws_conn_id=AWS_CONN_ID,
            region_name=REGION_NAME
        )

    def test_init(self):
        sensor = self._build_sensor()

        self.assertEqual(TASK_ID, sensor.task_id)
        self.assertEqual(EXECUTION_ARN, sensor.execution_arn)
        self.assertEqual(AWS_CONN_ID, sensor.aws_conn_id)
        self.assertEqual(REGION_NAME, sensor.region_name)

    @parameterized.expand([('FAILED',), ('TIMED_OUT',), ('ABORTED',)])
    @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
    def test_exceptions(self, mock_status, mock_hook):
        # Each failure status must make poke() raise.
        mock_hook.return_value.describe_execution.return_value = {'status': mock_status}
        sensor = self._build_sensor()

        with self.assertRaises(AirflowException):
            sensor.poke(self.mock_context)

    @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
    def test_running(self, mock_hook):
        mock_hook.return_value.describe_execution.return_value = {'status': 'RUNNING'}
        sensor = self._build_sensor()

        poke_result = sensor.poke(self.mock_context)

        self.assertFalse(poke_result)

    @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
    def test_succeeded(self, mock_hook):
        mock_hook.return_value.describe_execution.return_value = {'status': 'SUCCEEDED'}
        sensor = self._build_sensor()

        poke_result = sensor.poke(self.mock_context)

        self.assertTrue(poke_result)