From 7371cd46b7ee64365289945e47022ab6f5814e42 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Wed, 6 May 2020 15:09:26 -0600 Subject: [PATCH 01/16] add aws StepFunctions integrations --- .../amazon/aws/hooks/step_function.py | 80 ++++++++++ .../step_function_get_execution_output.py | 65 ++++++++ .../step_function_start_execution.py | 78 ++++++++++ .../aws/sensors/step_function_execution.py | 81 ++++++++++ .../amazon/aws/hooks/test_step_function.py | 61 ++++++++ ...test_step_function_get_execution_output.py | 73 +++++++++ .../test_step_function_start_execution.py | 79 ++++++++++ .../sensors/test_step_function_execution.py | 141 ++++++++++++++++++ 8 files changed, 658 insertions(+) create mode 100644 airflow/providers/amazon/aws/hooks/step_function.py create mode 100644 airflow/providers/amazon/aws/operators/step_function_get_execution_output.py create mode 100644 airflow/providers/amazon/aws/operators/step_function_start_execution.py create mode 100644 airflow/providers/amazon/aws/sensors/step_function_execution.py create mode 100644 tests/providers/amazon/aws/hooks/test_step_function.py create mode 100644 tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py create mode 100644 tests/providers/amazon/aws/operators/test_step_function_start_execution.py create mode 100644 tests/providers/amazon/aws/sensors/test_step_function_execution.py diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py new file mode 100644 index 0000000000000..7f52c1e9f221c --- /dev/null +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json + +from typing import Dict, Optional, Union + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class StepFunctionHook(AwsBaseHook): + """ + Interact with an AWS Step Functions State Machine. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, region_name=None, *args, **kwargs): + super().__init__(client_type='stepfunctions', *args, **kwargs) + + def start_execution(self, state_machine_arn: str, name: Optional[str] = None, + input: Union[Dict[str, any], str, None] = None): + """ + Start Execution of the State Machine. + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution + + :param state_machine_arn: AWS Step Function State Machine ARN + :type state_machine_arn: str + :param name: The name of the execution. + :type name: Optional[str] + :param input: JSON data input to pass to the State Machine + :type input: Union[Dict[str, any], str, None] + :return: Execution ARN + :rtype: str + """ + execution_args = { + 'stateMachineArn': state_machine_arn + } + if name is not None: + execution_args['name'] = name + if input is not None: + if isinstance(input, str): + execution_args['input'] = str + elif isinstance(input, dict): + execution_args['input'] = json.dumps(input) + + self.log.info(f'Executing Step Function State Machine: {state_machine_arn}') + + response = self.conn.start_execution(**execution_args) + return response['executionArn'] if 'executionArn' in response else None + + def describe_execution(self, execution_arn: str): + """ + Describes a State Machine Execution + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution + + :param execution_arn: ARN of the State Machine Execution + :type execution_arn: str + :return: Dict with Execution details + :rtype: dict + """ + return self.get_conn().describe_execution(executionArn=execution_arn) diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py new file mode 100644 index 0000000000000..36ff75f57baed --- /dev/null +++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook + + +class StepFunctionGetExecutionOutputOperator(BaseOperator): + """ + An Operator that begins execution of an Step Function State Machine + + Additional arguments may be specified and are passed down to the underlying BaseOperator. + + .. seealso:: + :class:`~airflow.models.BaseOperator` + + :param execution_arn: ARN of the Step Function State Machine Execution + :type execution_arn: str + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + :type aws_conn_id: str + """ + template_fields = ['execution_arn'] + template_ext = () + ui_color = '#f9c915' + + @apply_defaults + def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None, *args, **kwargs): + if kwargs.get('xcom_push') is not None: + raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") + super().__init__(*args, **kwargs) + self.execution_arn = execution_arn + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def execute(self, context): + hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + execution_status = hook.describe_execution(self.execution_arn) + execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None + + if self.do_xcom_push: + context['ti'].xcom_push(key='execution_output', value=execution_output) + + self.log.info(f'Got State Machine Execution output for {self.execution_arn}') + + return execution_output diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py new file mode 100644 index 0000000000000..844f2e65a3f73 --- /dev/null +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, Optional, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook + + +class StepFunctionStartExecutionOperator(BaseOperator): + """ + An Operator that begins execution of an Step Function State Machine + + Additional arguments may be specified and are passed down to the underlying BaseOperator. + + .. seealso:: + :class:`~airflow.models.BaseOperator` + + :param state_machine_arn: ARN of the Step Function State Machine + :type state_machine_arn: str + :param name: The name of the execution. + :type name: Optional[str] + :param input: JSON data input to pass to the State Machine + :type input: Union[Dict[str, any], str, None] + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn. + :type do_xcom_push: bool + """ + template_fields = ['state_machine_arn', 'name', 'input'] + template_ext = () + ui_color = '#f9c915' + + @apply_defaults + def __init__(self, state_machine_arn: str, name: Optional[str] = None, + input: Union[Dict[str, any], str, None] = None, + aws_conn_id='aws_default', region_name=None, + *args, **kwargs): + if kwargs.get('xcom_push') is not None: + raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") + super().__init__(*args, **kwargs) + self.state_machine_arn = state_machine_arn + self.name = name + self.input = input + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def execute(self, context): + hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input) + + if execution_arn is None: + raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}') + + if self.do_xcom_push: + context['ti'].xcom_push(key='execution_arn', value=execution_arn) + + self.log.info(f'Started State Machine execution for {self.state_machine_arn}: {execution_arn}') + + return execution_arn diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py new file mode 100644 index 0000000000000..62da535471a2a --- /dev/null +++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json + +from airflow.exceptions import AirflowException +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook + + +class StepFunctionExecutionSensor(BaseSensorOperator): + """ + Asks for the state of the Step Function State Machine Execution until it + reaches a failure state or success state. + If it fails, failing the task. + + On successful completion of the Execution the Sensor will do an XCom Push + of the State Machine's output to `output` + + :param execution_arn: execution_arn to check the state of + :type execution_arn: str + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + :type aws_conn_id: str + """ + + INTERMEDIATE_STATES = ('RUNNING',) + FAILURE_STATES = ('FAILED', 'TIMED_OUT', 'ABORTED',) + SUCCESS_STATES = ('SUCCEEDED',) + + template_fields = ['execution_arn'] + template_ext = () + ui_color = '#66c3ff' + + @apply_defaults + def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None, + *args, **kwargs): + if kwargs.get('xcom_push') is not None: + raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") + super().__init__(*args, **kwargs) + self.execution_arn = execution_arn + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.hook = None + + def poke(self, context): + execution_status = self.get_hook().describe_execution(self.execution_arn) + state = execution_status['status'] + output = json.loads(execution_status['output']) if 'output' in execution_status else None + + if state in self.FAILURE_STATES: + raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}') + + if state in self.INTERMEDIATE_STATES: + return False + + self.log.info(f'Doing xcom_push of output') + self.xcom_push(context, 'output', output) + return True + + def get_hook(self): + """Create and return a StepFunctionHook""" + if not self.hook: + self.log.info(f'region_name: {self.region_name}') + self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + return self.hook diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py new file mode 100644 index 0000000000000..e025f75c24614 --- /dev/null +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import unittest + +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook + +try: + from moto import mock_stepfunctions +except ImportError: + mock_stepfunctions = None + + +@unittest.skipIf(mock_stepfunctions is None, 'moto package not present') +class TestStepFunctionHook(unittest.TestCase): + + @mock_stepfunctions + def test_get_conn_returns_a_boto3_connection(self): + hook = StepFunctionHook(aws_conn_id='aws_default') + self.assertIsNotNone(hook.get_conn()) + + @mock_stepfunctions + def test_start_execution(self): + hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1') + state_machine = hook.get_conn().create_state_machine( + name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role') + + state_machine_arn = state_machine.get('stateMachineArn', None) + + execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, input={}) + + assert execution_arn is not None + + @mock_stepfunctions + def test_describe_execution(self): + hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1') + state_machine = hook.get_conn().create_state_machine( + name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role') + + state_machine_arn = state_machine.get('stateMachineArn', None) + + execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, input={}) + response = hook.describe_execution(execution_arn) + + assert 'input' in response diff --git a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py new file mode 100644 index 0000000000000..ff8eb9ed4df8c --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import unittest +from unittest import mock +from unittest.mock import MagicMock + +from airflow.providers.amazon.aws.operators.step_function_get_execution_output import StepFunctionGetExecutionOutputOperator + +TASK_ID = 'step_function_get_execution_output' +EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' +AWS_CONN_ID = 'aws_non_default' +REGION_NAME='us-west-2' + + +class TestStepFunctionGetExecutionOutputOperator(unittest.TestCase): + + def setUp(self): + self.mock_context = MagicMock() + + def test_init(self): + # Given / When + operator = StepFunctionGetExecutionOutputOperator( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + # Then + self.assertEqual(TASK_ID, operator.task_id) + self.assertEqual(EXECUTION_ARN, operator.execution_arn) + self.assertEqual(AWS_CONN_ID, operator.aws_conn_id) + self.assertEqual(REGION_NAME, operator.region_name) + + @mock.patch('airflow.providers.amazon.aws.operators.step_function_get_execution_output.StepFunctionHook') + def test_execute(self, mock_hook): + # Given + hook_response = { + 'output': '{}' + } + + hook_instance = mock_hook.return_value + hook_instance.describe_execution.return_value = hook_response + + operator = StepFunctionGetExecutionOutputOperator( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + # When + result = operator.execute(self.mock_context) + + # Then + self.assertEqual({}, result) diff --git a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py new file mode 100644 index 0000000000000..8af78a5460f2e --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import unittest +from unittest import mock +from unittest.mock import MagicMock + +from airflow.providers.amazon.aws.operators.step_function_start_execution import StepFunctionStartExecutionOperator + +TASK_ID = 'step_function_start_execution_task' +STATE_MACHINE_ARN = 'arn:aws:states:us-east-1:000000000000:stateMachine:pseudo-state-machine' +NAME = 'NAME' +INPUT = '{}' +AWS_CONN_ID = 'aws_non_default' +REGION_NAME='us-west-2' + + +class TestStepFunctionStartExecutionOperator(unittest.TestCase): + + def setUp(self): + self.mock_context = MagicMock() + + def test_init(self): + # Given / When + operator = StepFunctionStartExecutionOperator( + task_id=TASK_ID, + state_machine_arn=STATE_MACHINE_ARN, + name=NAME, + input=INPUT, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + # Then + self.assertEqual(TASK_ID, operator.task_id) + self.assertEqual(STATE_MACHINE_ARN, operator.state_machine_arn) + self.assertEqual(NAME, operator.name) + self.assertEqual(INPUT, operator.input) + self.assertEqual(AWS_CONN_ID, operator.aws_conn_id) + self.assertEqual(REGION_NAME, operator.region_name) + + @mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook') + def test_execute(self, mock_hook): + # Given + hook_response = 'arn:aws:states:us-east-1:123456789012:execution:pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' + + hook_instance = mock_hook.return_value + hook_instance.start_execution.return_value = hook_response + + operator = StepFunctionStartExecutionOperator( + task_id=TASK_ID, + state_machine_arn=STATE_MACHINE_ARN, + name=NAME, + input=INPUT, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + # When + result = operator.execute(self.mock_context) + + # Then + self.assertEqual(hook_response, result) diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py b/tests/providers/amazon/aws/sensors/test_step_function_execution.py new file mode 100644 index 0000000000000..f57c2cbd4b175 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py @@ -0,0 +1,141 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from unittest import mock +from unittest.mock import MagicMock + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.sensors.step_function_execution import StepFunctionExecutionSensor + +TASK_ID = 'step_function_execution_sensor' +EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' +AWS_CONN_ID = 'aws_non_default' +REGION_NAME='us-west-2' + + +class TestStepFunctionExecutionSensor(unittest.TestCase): + + def setUp(self): + self.mock_context = MagicMock() + + def test_init(self): + sensor = StepFunctionExecutionSensor( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + self.assertEqual(TASK_ID, sensor.task_id) + self.assertEqual(EXECUTION_ARN, sensor.execution_arn) + self.assertEqual(AWS_CONN_ID, sensor.aws_conn_id) + self.assertEqual(REGION_NAME, sensor.region_name) + + @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') + def test_failed(self, mock_hook): + hook_response = { + 'status': 'FAILED' + } + + hook_instance = mock_hook.return_value + hook_instance.describe_execution.return_value = hook_response + + sensor = StepFunctionExecutionSensor( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + with self.assertRaises(AirflowException): + sensor.poke(self.mock_context) + + @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') + def test_timed_out(self, mock_hook): + hook_response = { + 'status': 'TIMED_OUT' + } + + hook_instance = mock_hook.return_value + hook_instance.describe_execution.return_value = hook_response + + sensor = StepFunctionExecutionSensor( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + with self.assertRaises(AirflowException): + sensor.poke(self.mock_context) + + @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') + def test_aborted(self, mock_hook): + hook_response = { + 'status': 'ABORTED' + } + + hook_instance = mock_hook.return_value + hook_instance.describe_execution.return_value = hook_response + + sensor = StepFunctionExecutionSensor( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + with self.assertRaises(AirflowException): + sensor.poke(self.mock_context) + + @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') + def test_running(self, mock_hook): + hook_response = { + 'status': 'RUNNING' + } + + hook_instance = mock_hook.return_value + hook_instance.describe_execution.return_value = hook_response + + sensor = StepFunctionExecutionSensor( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + self.assertFalse(sensor.poke(self.mock_context)) + + @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') + def test_succeeded(self, mock_hook): + hook_response = { + 'status': 'SUCCEEDED' + } + + hook_instance = mock_hook.return_value + hook_instance.describe_execution.return_value = hook_response + + sensor = StepFunctionExecutionSensor( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + self.assertTrue(sensor.poke(self.mock_context)) From c7eb8fc258ce6085faff80a9f52bfce7c7fd3448 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Wed, 6 May 2020 15:58:28 -0600 Subject: [PATCH 02/16] added docs for step functions hook and operators --- docs/operators-and-hooks-ref.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst index 0c052e93ca9b8..355e443366b54 100644 --- a/docs/operators-and-hooks-ref.rst +++ b/docs/operators-and-hooks-ref.rst @@ -437,6 +437,13 @@ These integrations allow you to perform various operations within the Amazon Web - :mod:`airflow.providers.amazon.aws.sensors.s3_key`, :mod:`airflow.providers.amazon.aws.sensors.s3_prefix` + * - `AWS Step Functions `__ + - + - :mod:`airflow.providers.amazon.aws.hooks.step_function` + - :mod:`airflow.providers.amazon.aws.operators.step_function_start_execution`, + :mod:`airflow.providers.amazon.aws.operators.step_function_get_execution_output`, + - :mod:`airflow.providers.amazon.aws.sensors.step_function_execution`, + Transfer operators and hooks '''''''''''''''''''''''''''' From 2ee44767abdc63a63bfc8e62438223e699ef22e6 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Wed, 6 May 2020 16:05:44 -0600 Subject: [PATCH 03/16] fixed pylint errors --- .../operators/test_step_function_get_execution_output.py | 8 +++++--- .../aws/operators/test_step_function_start_execution.py | 8 +++++--- .../amazon/aws/sensors/test_step_function_execution.py | 5 +++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py index ff8eb9ed4df8c..2aad2854b4e78 100644 --- a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py +++ b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py @@ -21,12 +21,14 @@ from unittest import mock from unittest.mock import MagicMock -from airflow.providers.amazon.aws.operators.step_function_get_execution_output import StepFunctionGetExecutionOutputOperator +from airflow.providers.amazon.aws.operators.step_function_get_execution_output import \ + StepFunctionGetExecutionOutputOperator TASK_ID = 'step_function_get_execution_output' -EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' +EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\ + 'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' AWS_CONN_ID = 'aws_non_default' -REGION_NAME='us-west-2' +REGION_NAME = 'us-west-2' class TestStepFunctionGetExecutionOutputOperator(unittest.TestCase): diff --git a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py index 8af78a5460f2e..e6dfcd0b23e51 100644 --- a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py +++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py @@ -21,14 +21,15 @@ from unittest import mock from unittest.mock import MagicMock -from airflow.providers.amazon.aws.operators.step_function_start_execution import StepFunctionStartExecutionOperator +from airflow.providers.amazon.aws.operators.step_function_start_execution \ + import StepFunctionStartExecutionOperator TASK_ID = 'step_function_start_execution_task' STATE_MACHINE_ARN = 'arn:aws:states:us-east-1:000000000000:stateMachine:pseudo-state-machine' NAME = 'NAME' INPUT = '{}' AWS_CONN_ID = 'aws_non_default' -REGION_NAME='us-west-2' +REGION_NAME = 'us-west-2' class TestStepFunctionStartExecutionOperator(unittest.TestCase): @@ -58,7 +59,8 @@ def test_init(self): @mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook') def test_execute(self, mock_hook): # Given - hook_response = 'arn:aws:states:us-east-1:123456789012:execution:pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' + hook_response = 'arn:aws:states:us-east-1:123456789012:execution:'\ + 'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' hook_instance = mock_hook.return_value hook_instance.start_execution.return_value = hook_response diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py b/tests/providers/amazon/aws/sensors/test_step_function_execution.py index f57c2cbd4b175..23a255b2ae075 100644 --- a/tests/providers/amazon/aws/sensors/test_step_function_execution.py +++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py @@ -24,9 +24,10 @@ from airflow.providers.amazon.aws.sensors.step_function_execution import StepFunctionExecutionSensor TASK_ID = 'step_function_execution_sensor' -EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' +EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\ + 'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' AWS_CONN_ID = 'aws_non_default' -REGION_NAME='us-west-2' +REGION_NAME = 'us-west-2' class TestStepFunctionExecutionSensor(unittest.TestCase): From 33b30e0204d85a5bfa551b5742ad75a6b42da9ba Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Wed, 6 May 2020 17:01:15 -0600 Subject: [PATCH 04/16] apply isort fixes to imports --- airflow/providers/amazon/aws/hooks/step_function.py | 1 - .../aws/operators/step_function_get_execution_output.py | 3 +-- .../amazon/aws/operators/step_function_start_execution.py | 3 +-- .../providers/amazon/aws/sensors/step_function_execution.py | 3 +-- .../aws/operators/test_step_function_get_execution_output.py | 5 +++-- .../aws/operators/test_step_function_start_execution.py | 5 +++-- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index 7f52c1e9f221c..5a9793d68b2e7 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -16,7 +16,6 @@ # under the License. import json - from typing import Dict, Optional, Union from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py index 36ff75f57baed..4a12b5e779465 100644 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -19,9 +19,8 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.utils.decorators import apply_defaults class StepFunctionGetExecutionOutputOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py index 844f2e65a3f73..63ed099097676 100644 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -19,9 +19,8 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.utils.decorators import apply_defaults class StepFunctionStartExecutionOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py index 62da535471a2a..676241f79bdd5 100644 --- a/airflow/providers/amazon/aws/sensors/step_function_execution.py +++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py @@ -18,11 +18,10 @@ import json from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults -from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook - class StepFunctionExecutionSensor(BaseSensorOperator): """ diff --git a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py index 2aad2854b4e78..8997df9fdaba8 100644 --- a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py +++ b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py @@ -21,8 +21,9 @@ from unittest import mock from unittest.mock import MagicMock -from airflow.providers.amazon.aws.operators.step_function_get_execution_output import \ - StepFunctionGetExecutionOutputOperator +from airflow.providers.amazon.aws.operators.step_function_get_execution_output import ( + StepFunctionGetExecutionOutputOperator, +) TASK_ID = 'step_function_get_execution_output' EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\ diff --git a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py index e6dfcd0b23e51..de2ca0e0e8f93 100644 --- a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py +++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py @@ -21,8 +21,9 @@ from unittest import mock from unittest.mock import MagicMock -from airflow.providers.amazon.aws.operators.step_function_start_execution \ - import StepFunctionStartExecutionOperator +from airflow.providers.amazon.aws.operators.step_function_start_execution import ( + StepFunctionStartExecutionOperator, +) TASK_ID = 'step_function_start_execution_task' STATE_MACHINE_ARN = 'arn:aws:states:us-east-1:000000000000:stateMachine:pseudo-state-machine' From 0c7467a6339eeaac8b2511bf41afedc16a3069a5 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Wed, 6 May 2020 17:06:49 -0600 Subject: [PATCH 05/16] fixed problems identified by mypy --- airflow/providers/amazon/aws/hooks/step_function.py | 4 ++-- .../amazon/aws/operators/step_function_start_execution.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index 5a9793d68b2e7..3817668d2954a 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -36,7 +36,7 @@ def __init__(self, region_name=None, *args, **kwargs): super().__init__(client_type='stepfunctions', *args, **kwargs) def start_execution(self, state_machine_arn: str, name: Optional[str] = None, - input: Union[Dict[str, any], str, None] = None): + input: Union[dict, str, None] = None): """ Start Execution of the State Machine. https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution @@ -57,7 +57,7 @@ def start_execution(self, state_machine_arn: str, name: Optional[str] = None, execution_args['name'] = name if input is not None: if isinstance(input, str): - execution_args['input'] = str + execution_args['input'] = input elif isinstance(input, dict): execution_args['input'] = json.dumps(input) diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py index 63ed099097676..7997b654afe93 100644 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -49,7 +49,7 @@ class StepFunctionStartExecutionOperator(BaseOperator): @apply_defaults def __init__(self, state_machine_arn: str, name: Optional[str] = None, - input: Union[Dict[str, any], str, None] = None, + input: Union[dict, str, None] = None, aws_conn_id='aws_default', region_name=None, *args, **kwargs): if kwargs.get('xcom_push') is not None: From 537b735d1b7f2676cc18b483f0feda50e4bbb459 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Wed, 6 May 2020 17:19:17 -0600 Subject: [PATCH 06/16] fixed flake8 unused import error --- airflow/providers/amazon/aws/hooks/step_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index 3817668d2954a..adf0a187ce443 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -16,7 +16,7 @@ # under the License. import json -from typing import Dict, Optional, Union +from typing import Optional, Union from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook From 8e451d45752b77bbfc6355cfeea76743b5a85a44 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Thu, 7 May 2020 07:36:52 -0600 Subject: [PATCH 07/16] fixed pylint errors --- .../amazon/aws/hooks/step_function.py | 18 +++++++++--------- .../step_function_get_execution_output.py | 2 +- .../operators/step_function_start_execution.py | 12 ++++++------ .../aws/sensors/step_function_execution.py | 3 +-- .../amazon/aws/hooks/test_step_function.py | 4 ++-- .../test_step_function_start_execution.py | 4 ++-- 6 files changed, 21 insertions(+), 22 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index adf0a187ce443..868d91f850d26 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -36,7 +36,7 @@ def __init__(self, region_name=None, *args, **kwargs): super().__init__(client_type='stepfunctions', *args, **kwargs) def start_execution(self, state_machine_arn: str, name: Optional[str] = None, - input: Union[dict, str, None] = None): + state_machine_input: Union[dict, str, None] = None): """ Start Execution of the State Machine. https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution @@ -45,8 +45,8 @@ def start_execution(self, state_machine_arn: str, name: Optional[str] = None, :type state_machine_arn: str :param name: The name of the execution. :type name: Optional[str] - :param input: JSON data input to pass to the State Machine - :type input: Union[Dict[str, any], str, None] + :param state_machine_input: JSON data input to pass to the State Machine + :type state_machine_input: Union[Dict[str, any], str, None] :return: Execution ARN :rtype: str """ @@ -55,13 +55,13 @@ def start_execution(self, state_machine_arn: str, name: Optional[str] = None, } if name is not None: execution_args['name'] = name - if input is not None: - if isinstance(input, str): - execution_args['input'] = input - elif isinstance(input, dict): - execution_args['input'] = json.dumps(input) + if state_machine_input is not None: + if isinstance(state_machine_input, str): + execution_args['input'] = state_machine_input + elif isinstance(state_machine_input, dict): + execution_args['input'] = json.dumps(state_machine_input) - self.log.info(f'Executing Step Function State Machine: {state_machine_arn}') + self.log.info('Executing Step Function State Machine: %s', state_machine_arn) response = self.conn.start_execution(**execution_args) return response['executionArn'] if 'executionArn' in response else None diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py index 4a12b5e779465..32948476ad81d 100644 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -59,6 +59,6 @@ def execute(self, context): if self.do_xcom_push: context['ti'].xcom_push(key='execution_output', value=execution_output) - self.log.info(f'Got State Machine Execution output for {self.execution_arn}') + self.log.info('Got State Machine Execution output for %s', self.execution_arn) return execution_output diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py index 7997b654afe93..39c5d0a79493e 100644 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Optional, Union +from typing import Optional, Union from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -36,8 +36,8 @@ class StepFunctionStartExecutionOperator(BaseOperator): :type state_machine_arn: str :param name: The name of the execution. :type name: Optional[str] - :param input: JSON data input to pass to the State Machine - :type input: Union[Dict[str, any], str, None] + :param state_machine_input: JSON data input to pass to the State Machine + :type state_machine_input: Union[Dict[str, any], str, None] :param aws_conn_id: aws connection to uses :type aws_conn_id: str :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn. @@ -49,7 +49,7 @@ class StepFunctionStartExecutionOperator(BaseOperator): @apply_defaults def __init__(self, state_machine_arn: str, name: Optional[str] = None, - input: Union[dict, str, None] = None, + state_machine_input: Union[dict, str, None] = None, aws_conn_id='aws_default', region_name=None, *args, **kwargs): if kwargs.get('xcom_push') is not None: @@ -57,7 +57,7 @@ def __init__(self, state_machine_arn: str, name: Optional[str] = None, super().__init__(*args, **kwargs) self.state_machine_arn = state_machine_arn self.name = name - self.input = input + self.input = state_machine_input self.aws_conn_id = aws_conn_id self.region_name = region_name @@ -72,6 +72,6 @@ def execute(self, context): if self.do_xcom_push: context['ti'].xcom_push(key='execution_arn', value=execution_arn) - self.log.info(f'Started State Machine execution for {self.state_machine_arn}: {execution_arn}') + self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn) return execution_arn diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py index 676241f79bdd5..10c42595d848d 100644 --- a/airflow/providers/amazon/aws/sensors/step_function_execution.py +++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py @@ -68,13 +68,12 @@ def poke(self, context): if state in self.INTERMEDIATE_STATES: return False - self.log.info(f'Doing xcom_push of output') + self.log.info('Doing xcom_push of output') self.xcom_push(context, 'output', output) return True def get_hook(self): """Create and return a StepFunctionHook""" if not self.hook: - self.log.info(f'region_name: {self.region_name}') self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) return self.hook diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py index e025f75c24614..14942b10d1267 100644 --- a/tests/providers/amazon/aws/hooks/test_step_function.py +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -43,7 +43,7 @@ def test_start_execution(self): state_machine_arn = state_machine.get('stateMachineArn', None) - execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, input={}) + execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, state_machine_input={}) assert execution_arn is not None @@ -55,7 +55,7 @@ def test_describe_execution(self): state_machine_arn = state_machine.get('stateMachineArn', None) - execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, input={}) + execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, state_machine_input={}) response = hook.describe_execution(execution_arn) assert 'input' in response diff --git a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py index de2ca0e0e8f93..5f6c336521594 100644 --- a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py +++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py @@ -44,7 +44,7 @@ def test_init(self): task_id=TASK_ID, state_machine_arn=STATE_MACHINE_ARN, name=NAME, - input=INPUT, + state_machine_input=INPUT, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME ) @@ -70,7 +70,7 @@ def test_execute(self, mock_hook): task_id=TASK_ID, state_machine_arn=STATE_MACHINE_ARN, name=NAME, - input=INPUT, + state_machine_input=INPUT, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME ) From 161c96c698bc76411ead1cacc2e68bf4c439f826 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Thu, 7 May 2020 07:56:28 -0600 Subject: [PATCH 08/16] fixed pylint error: line too long after refactor --- tests/providers/amazon/aws/hooks/test_step_function.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py index 14942b10d1267..37fcf671ccc4f 100644 --- a/tests/providers/amazon/aws/hooks/test_step_function.py +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -43,7 +43,8 @@ def test_start_execution(self): state_machine_arn = state_machine.get('stateMachineArn', None) - execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, state_machine_input={}) + execution_arn = hook.start_execution( + state_machine_arn=state_machine_arn, name=None, state_machine_input={}) assert execution_arn is not None @@ -55,7 +56,8 @@ def test_describe_execution(self): state_machine_arn = state_machine.get('stateMachineArn', None) - execution_arn = hook.start_execution(state_machine_arn=state_machine_arn, name=None, state_machine_input={}) + execution_arn = hook.start_execution( + state_machine_arn=state_machine_arn, name=None, state_machine_input={}) response = hook.describe_execution(execution_arn) assert 'input' in response From 5d3a0cab6d637026a0700485cb3bb0b13a9c6b42 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 11 May 2020 08:36:49 -0600 Subject: [PATCH 09/16] declare and cleanup function return types --- airflow/providers/amazon/aws/hooks/step_function.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index 868d91f850d26..f0e10400d95ee 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -36,7 +36,7 @@ def __init__(self, region_name=None, *args, **kwargs): super().__init__(client_type='stepfunctions', *args, **kwargs) def start_execution(self, state_machine_arn: str, name: Optional[str] = None, - state_machine_input: Union[dict, str, None] = None): + state_machine_input: Union[dict, str, None] = None) -> str: """ Start Execution of the State Machine. https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution @@ -64,9 +64,9 @@ def start_execution(self, state_machine_arn: str, name: Optional[str] = None, self.log.info('Executing Step Function State Machine: %s', state_machine_arn) response = self.conn.start_execution(**execution_args) - return response['executionArn'] if 'executionArn' in response else None + return response.get('executionArn', None) - def describe_execution(self, execution_arn: str): + def describe_execution(self, execution_arn: str) -> dict: """ Describes a State Machine Execution https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.describe_execution From 074d30fbb1947fec66ab1f9fb6e22b7c9a09a85f Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 11 May 2020 08:37:49 -0600 Subject: [PATCH 10/16] removed redundant xcom_push operations --- .../amazon/aws/operators/step_function_get_execution_output.py | 3 --- .../amazon/aws/operators/step_function_start_execution.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py index 32948476ad81d..f1aff45dcbd4b 100644 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -56,9 +56,6 @@ def execute(self, context): execution_status = hook.describe_execution(self.execution_arn) execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None - if self.do_xcom_push: - context['ti'].xcom_push(key='execution_output', value=execution_output) - self.log.info('Got State Machine Execution output for %s', self.execution_arn) return execution_output diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py index 39c5d0a79493e..c4b239dce262e 100644 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -69,9 +69,6 @@ def execute(self, context): if execution_arn is None: raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}') - if self.do_xcom_push: - context['ti'].xcom_push(key='execution_arn', value=execution_arn) - self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn) return execution_arn From f51f875528f7f94fb882cfc5ca1f1aad07a6ec31 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 11 May 2020 08:38:55 -0600 Subject: [PATCH 11/16] explicitly test boto3.client client_type --- tests/providers/amazon/aws/hooks/test_step_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py index 37fcf671ccc4f..d919f93504925 100644 --- a/tests/providers/amazon/aws/hooks/test_step_function.py +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -33,7 +33,7 @@ class TestStepFunctionHook(unittest.TestCase): @mock_stepfunctions def test_get_conn_returns_a_boto3_connection(self): hook = StepFunctionHook(aws_conn_id='aws_default') - self.assertIsNotNone(hook.get_conn()) + self.assertEquals('stepfunctions', hook.get_conn().meta.service_model.service_name) @mock_stepfunctions def test_start_execution(self): From f3b041f27b065782340c27331e705f0f9faeb380 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 11 May 2020 08:39:58 -0600 Subject: [PATCH 12/16] use @parameterized.expand to elminate redunant test ops --- .../sensors/test_step_function_execution.py | 73 +++---------------- 1 file changed, 12 insertions(+), 61 deletions(-) diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py b/tests/providers/amazon/aws/sensors/test_step_function_execution.py index 23a255b2ae075..a653a66ddb128 100644 --- a/tests/providers/amazon/aws/sensors/test_step_function_execution.py +++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py @@ -20,6 +20,8 @@ from unittest import mock from unittest.mock import MagicMock +from parameterized import parameterized + from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.sensors.step_function_execution import StepFunctionExecutionSensor @@ -48,48 +50,11 @@ def test_init(self): self.assertEqual(AWS_CONN_ID, sensor.aws_conn_id) self.assertEqual(REGION_NAME, sensor.region_name) + @parameterized.expand([('FAILED',), ('TIMED_OUT',), ('ABORTED',)]) @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') - def test_failed(self, mock_hook): - hook_response = { - 'status': 'FAILED' - } - - hook_instance = mock_hook.return_value - hook_instance.describe_execution.return_value = hook_response - - sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, - execution_arn=EXECUTION_ARN, - aws_conn_id=AWS_CONN_ID, - region_name=REGION_NAME - ) - - with self.assertRaises(AirflowException): - sensor.poke(self.mock_context) - - @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') - def test_timed_out(self, mock_hook): - hook_response = { - 'status': 'TIMED_OUT' - } - - hook_instance = mock_hook.return_value - hook_instance.describe_execution.return_value = hook_response - - sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, - execution_arn=EXECUTION_ARN, - aws_conn_id=AWS_CONN_ID, - region_name=REGION_NAME - ) - - with self.assertRaises(AirflowException): - sensor.poke(self.mock_context) - - @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') - def test_aborted(self, mock_hook): + def test_exceptions(self, mock_status, mock_hook): hook_response = { - 'status': 'ABORTED' + 'status': mock_status } hook_instance = mock_hook.return_value @@ -105,28 +70,11 @@ def test_aborted(self, mock_hook): with self.assertRaises(AirflowException): sensor.poke(self.mock_context) + @parameterized.expand([('RUNNING',), ('SUCCEEDED',)]) @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') - def test_running(self, mock_hook): - hook_response = { - 'status': 'RUNNING' - } - - hook_instance = mock_hook.return_value - hook_instance.describe_execution.return_value = hook_response - - sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, - execution_arn=EXECUTION_ARN, - aws_conn_id=AWS_CONN_ID, - region_name=REGION_NAME - ) - - self.assertFalse(sensor.poke(self.mock_context)) - - @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') - def test_succeeded(self, mock_hook): + def test_returns(self, mock_status, mock_hook): hook_response = { - 'status': 'SUCCEEDED' + 'status': mock_status } hook_instance = mock_hook.return_value @@ -139,4 +87,7 @@ def test_succeeded(self, mock_hook): region_name=REGION_NAME ) - self.assertTrue(sensor.poke(self.mock_context)) + if mock_status == 'RUNNING': + self.assertFalse(sensor.poke(self.mock_context)) + else: + self.assertTrue(sensor.poke(self.mock_context)) From aed57916013ca02b1ad365535ad4bb89e5e2f2cc Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 11 May 2020 09:00:02 -0600 Subject: [PATCH 13/16] replace usage of deprecated assertEquals with assertEqual --- tests/providers/amazon/aws/hooks/test_step_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py index d919f93504925..679d2e44039ad 100644 --- a/tests/providers/amazon/aws/hooks/test_step_function.py +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -33,7 +33,7 @@ class TestStepFunctionHook(unittest.TestCase): @mock_stepfunctions def test_get_conn_returns_a_boto3_connection(self): hook = StepFunctionHook(aws_conn_id='aws_default') - self.assertEquals('stepfunctions', hook.get_conn().meta.service_model.service_name) + self.assertEqual('stepfunctions', hook.get_conn().meta.service_model.service_name) @mock_stepfunctions def test_start_execution(self): From f083b99351de0b32bf94346e94605b4013da04f3 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 11 May 2020 10:57:10 -0600 Subject: [PATCH 14/16] separated unittests for SUCCEEDED and RUNNING statuses --- .../sensors/test_step_function_execution.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py b/tests/providers/amazon/aws/sensors/test_step_function_execution.py index a653a66ddb128..237f8ef424cb2 100644 --- a/tests/providers/amazon/aws/sensors/test_step_function_execution.py +++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py @@ -70,11 +70,28 @@ def test_exceptions(self, mock_status, mock_hook): with self.assertRaises(AirflowException): sensor.poke(self.mock_context) - @parameterized.expand([('RUNNING',), ('SUCCEEDED',)]) @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') - def test_returns(self, mock_status, mock_hook): + def test_running(self, mock_hook): hook_response = { - 'status': mock_status + 'status': 'RUNNING' + } + + hook_instance = mock_hook.return_value + hook_instance.describe_execution.return_value = hook_response + + sensor = StepFunctionExecutionSensor( + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME + ) + + self.assertFalse(sensor.poke(self.mock_context)) + + @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') + def test_succeeded(self, mock_hook): + hook_response = { + 'status': 'SUCCEEDED' } hook_instance = mock_hook.return_value @@ -87,7 +104,4 @@ def test_returns(self, mock_status, mock_hook): region_name=REGION_NAME ) - if mock_status == 'RUNNING': - self.assertFalse(sensor.poke(self.mock_context)) - else: - self.assertTrue(sensor.poke(self.mock_context)) + self.assertTrue(sensor.poke(self.mock_context)) From c46733d4ec0770e6f5d654ccd5aae4231ad28313 Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 18 May 2020 07:18:55 -0600 Subject: [PATCH 15/16] removed unnecessary deprecation error for xcom_push --- .../amazon/aws/operators/step_function_get_execution_output.py | 2 -- .../amazon/aws/operators/step_function_start_execution.py | 2 -- airflow/providers/amazon/aws/sensors/step_function_execution.py | 2 -- 3 files changed, 6 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py index f1aff45dcbd4b..23863af15a86d 100644 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -43,8 +43,6 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator): @apply_defaults def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None, *args, **kwargs): - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") super().__init__(*args, **kwargs) self.execution_arn = execution_arn self.aws_conn_id = aws_conn_id diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py index c4b239dce262e..f5ea75ca3994d 100644 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -52,8 +52,6 @@ def __init__(self, state_machine_arn: str, name: Optional[str] = None, state_machine_input: Union[dict, str, None] = None, aws_conn_id='aws_default', region_name=None, *args, **kwargs): - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") super().__init__(*args, **kwargs) self.state_machine_arn = state_machine_arn self.name = name diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py index 10c42595d848d..0cc3caf271806 100644 --- a/airflow/providers/amazon/aws/sensors/step_function_execution.py +++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py @@ -49,8 +49,6 @@ class StepFunctionExecutionSensor(BaseSensorOperator): @apply_defaults def __init__(self, execution_arn: str, aws_conn_id='aws_default', region_name=None, *args, **kwargs): - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") super().__init__(*args, **kwargs) self.execution_arn = execution_arn self.aws_conn_id = aws_conn_id From e10608a5957fe01f2f43e0c6f988f7da0d438e8b Mon Sep 17 00:00:00 2001 From: Chauncy McCaughey Date: Mon, 18 May 2020 07:45:44 -0600 Subject: [PATCH 16/16] removed unused AirflowException import --- .../amazon/aws/operators/step_function_get_execution_output.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py index 23863af15a86d..2ef531c782b3f 100644 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -17,7 +17,6 @@ import json -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook from airflow.utils.decorators import apply_defaults