diff --git a/airflow/providers/amazon/aws/example_dags/example_step_functions.py b/airflow/providers/amazon/aws/example_dags/example_step_functions.py
deleted file mode 100644
index 02763e3ea13f1..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_step_functions.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from os import environ
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.step_function import (
-    StepFunctionGetExecutionOutputOperator,
-    StepFunctionStartExecutionOperator,
-)
-from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor
-
-STEP_FUNCTIONS_STATE_MACHINE_ARN = environ.get('STEP_FUNCTIONS_STATE_MACHINE_ARN', 'state_machine_arn')
-
-with DAG(
-    dag_id='example_step_functions',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-
-    # [START howto_operator_step_function_start_execution]
-    start_execution = StepFunctionStartExecutionOperator(
-        task_id='start_execution', state_machine_arn=STEP_FUNCTIONS_STATE_MACHINE_ARN
-    )
-    # [END howto_operator_step_function_start_execution]
-
-    # [START howto_sensor_step_function_execution]
-    wait_for_execution = StepFunctionExecutionSensor(
-        task_id='wait_for_execution', execution_arn=start_execution.output
-    )
-    # [END howto_sensor_step_function_execution]
-
-    # [START howto_operator_step_function_get_execution_output]
-    get_execution_output = StepFunctionGetExecutionOutputOperator(
-        task_id='get_execution_output', execution_arn=start_execution.output
-    )
-    # [END howto_operator_step_function_get_execution_output]
-
-    chain(start_execution, wait_for_execution, get_execution_output)
diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
index e29d3194f6908..91984dc5986b6 100644
--- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
+++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
@@ -39,7 +39,7 @@ Start an AWS Step Functions state machine execution
 To start a new AWS Step Functions state machine execution you can use
 :class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`.
 
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_step_function_start_execution]
@@ -53,7 +53,7 @@ Get an AWS Step Functions execution output
 To fetch the output from an AWS Step Function state machine execution you can use
 :class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionGetExecutionOutputOperator`.
 
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_step_function_get_execution_output]
@@ -70,7 +70,7 @@ Wait on an AWS Step Functions state machine execution state
 To wait on the state of an AWS Step Function state machine execution until it reaches a terminal state you can use
 :class:`~airflow.providers.amazon.aws.sensors.step_function.StepFunctionExecutionSensor`.
 
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py
     :language: python
     :dedent: 4
     :start-after: [START howto_sensor_step_function_execution]
diff --git a/tests/system/providers/amazon/aws/example_step_functions.py b/tests/system/providers/amazon/aws/example_step_functions.py
new file mode 100644
index 0000000000000..98d6b7a7436b4
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_step_functions.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.operators.step_function import (
+    StepFunctionGetExecutionOutputOperator,
+    StepFunctionStartExecutionOperator,
+)
+from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
+
+DAG_ID = 'example_step_functions'
+
+# Externally fetched variables:
+ROLE_ARN_KEY = 'ROLE_ARN'
+
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+STATE_MACHINE_DEFINITION = {
+    "StartAt": "Wait",
+    "States": {"Wait": {"Type": "Wait", "Seconds": 7, "Next": "Success"}, "Success": {"Type": "Succeed"}},
+}
+
+
+@task
+def create_state_machine(env_id, role_arn):
+    # Create a Step Functions State Machine and return the ARN for use by
+    # downstream tasks.
+    return (
+        StepFunctionHook()
+        .get_conn()
+        .create_state_machine(
+            name=f'{DAG_ID}_{env_id}',
+            definition=json.dumps(STATE_MACHINE_DEFINITION),
+            roleArn=role_arn,
+        )['stateMachineArn']
+    )
+
+
+@task
+def delete_state_machine(state_machine_arn):
+    StepFunctionHook().get_conn().delete_state_machine(stateMachineArn=state_machine_arn)
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule_interval='@once',
+    start_date=datetime(2021, 1, 1),
+    tags=['example'],
+    catchup=False,
+) as dag:
+
+    # This context contains the ENV_ID and any env variables requested when the
+    # task was built above. Access the info as you would any other TaskFlow task.
+    test_context = sys_test_context_task()
+    env_id = test_context[ENV_ID_KEY]
+    role_arn = test_context[ROLE_ARN_KEY]
+
+    state_machine_arn = create_state_machine(env_id, role_arn)
+
+    # [START howto_operator_step_function_start_execution]
+    start_execution = StepFunctionStartExecutionOperator(
+        task_id='start_execution', state_machine_arn=state_machine_arn
+    )
+    # [END howto_operator_step_function_start_execution]
+
+    # [START howto_sensor_step_function_execution]
+    wait_for_execution = StepFunctionExecutionSensor(
+        task_id='wait_for_execution', execution_arn=start_execution.output
+    )
+    # [END howto_sensor_step_function_execution]
+
+    # [START howto_operator_step_function_get_execution_output]
+    get_execution_output = StepFunctionGetExecutionOutputOperator(
+        task_id='get_execution_output', execution_arn=start_execution.output
+    )
+    # [END howto_operator_step_function_get_execution_output]
+
+    chain(
+        # TEST SETUP
+        test_context,
+        state_machine_arn,
+        # TEST BODY
+        start_execution,
+        wait_for_execution,
+        get_execution_output,
+        # TEST TEARDOWN
+        delete_state_machine(state_machine_arn),
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/amazon/aws/utils/__init__.py b/tests/system/providers/amazon/aws/utils/__init__.py
index 09b7fff0a18b1..38f606d0fbde9 100644
--- a/tests/system/providers/amazon/aws/utils/__init__.py
+++ b/tests/system/providers/amazon/aws/utils/__init__.py
@@ -26,7 +26,10 @@
 from botocore.client import BaseClient
 from botocore.exceptions import NoCredentialsError
 
+from airflow.decorators import task
+
 ENV_ID_ENVIRON_KEY: str = 'SYSTEM_TESTS_ENV_ID'
+ENV_ID_KEY: str = 'ENV_ID'
 DEFAULT_ENV_ID_PREFIX: str = 'env'
 DEFAULT_ENV_ID_LEN: int = 8
 DEFAULT_ENV_ID: str = f'{DEFAULT_ENV_ID_PREFIX}{str(uuid4())[:DEFAULT_ENV_ID_LEN]}'
@@ -76,19 +79,19 @@ def _validate_env_id(env_id: str) -> str:
     return env_id.lower()
 
 
-def _fetch_from_ssm(key: str) -> str:
+def _fetch_from_ssm(key: str, test_name: Optional[str] = None) -> str:
     """
     Test values are stored in the SSM Value as a JSON-encoded dict of key/value pairs.
 
     :param key: The key to search for within the returned Parameter Value.
     :return: The value of the provided key from SSM
     """
-    test_name: str = _get_test_name()
+    _test_name: str = test_name if test_name else _get_test_name()
     ssm_client: BaseClient = boto3.client('ssm')
     value: str = ''
 
     try:
-        value = json.loads(ssm_client.get_parameter(Name=test_name)['Parameter']['Value'])[key]
+        value = json.loads(ssm_client.get_parameter(Name=_test_name)['Parameter']['Value'])[key]
         # Since a default value after the SSM check is allowed, these exceptions should not stop execution.
     except NoCredentialsError:
         # No boto credentials found.
@@ -102,7 +105,49 @@ def _fetch_from_ssm(key: str) -> str:
     return value
 
 
-def fetch_variable(key: str, default_value: Optional[str] = None) -> str:
+class SystemTestContextBuilder:
+    """This builder class ultimately constructs a TaskFlow task which is run at
+    runtime (task execution time). This task generates and stores the test ENV_ID as well
+    as any external resources requested (e.g. IAM Roles, VPC, etc)"""
+
+    def __init__(self):
+        self.variables = []
+        self.variable_defaults = {}
+        self.test_name = _get_test_name()
+        self.env_id = set_env_id()
+
+    def add_variable(self, variable_name: str, **kwargs):
+        """Register a variable to fetch from environment or cloud parameter store"""
+        self.variables.append(variable_name)
+        # default_value is accepted via kwargs so that it is completely optional and no
+        # default value needs to be provided in the method stub (otherwise we wouldn't
+        # be able to tell the difference between our default value and one provided by
+        # the caller)
+        if 'default_value' in kwargs:
+            self.variable_defaults[variable_name] = kwargs['default_value']
+
+        return self  # Builder recipe; returning self allows chaining
+
+    def build(self):
+        """Build and return a TaskFlow task which will create an env_id and
+        fetch requested variables, storing everything in XCom for downstream
+        tasks to use."""
+
+        @task
+        def variable_fetcher(**kwargs):
+            ti = kwargs['ti']
+            for variable in self.variables:
+                default_value = self.variable_defaults.get(variable, None)
+                value = fetch_variable(variable, default_value, test_name=self.test_name)
+                ti.xcom_push(variable, value)
+
+            # Fetch/generate ENV_ID and store it in XCOM
+            ti.xcom_push(ENV_ID_KEY, self.env_id)
+
+        return variable_fetcher
+
+
+def fetch_variable(key: str, default_value: Optional[str] = None, test_name: Optional[str] = None) -> str:
     """
     Given a Parameter name: first check for an existing Environment Variable,
     then check SSM for a value. If neither are available, fall back on the
@@ -113,7 +158,7 @@ def fetch_variable(key: str, default_value: Optional[str] = None) -> str:
     :return: The value of the parameter.
     """
 
-    value: Optional[str] = os.getenv(key, _fetch_from_ssm(key)) or default_value
+    value: Optional[str] = os.getenv(key, _fetch_from_ssm(key, test_name=test_name)) or default_value
     if not value:
         raise ValueError(NO_VALUE_MSG.format(key=key))
    return value
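
For orientation, here is a minimal sketch (not part of the patch) of how another system test could consume SystemTestContextBuilder. The DAG id and the BUCKET_NAME variable are hypothetical placeholders; the XCom-by-key indexing mirrors the example DAG added above.

from datetime import datetime

from airflow import DAG
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder

# Parse time: register the variables to fetch; nothing is looked up yet.
sys_test_context_task = (
    SystemTestContextBuilder()
    .add_variable('ROLE_ARN')  # no default: fetch_variable() raises ValueError if no value is found
    .add_variable('BUCKET_NAME', default_value='fallback-bucket')  # hypothetical optional variable
    .build()
)

with DAG(
    dag_id='example_builder_usage',  # hypothetical DAG id
    schedule_interval='@once',
    start_date=datetime(2021, 1, 1),
    catchup=False,
) as dag:
    # Run time: variable_fetcher pushes ENV_ID plus each requested variable to
    # XCom under its own key, so indexing the returned XComArg by key defers
    # resolution until task execution.
    test_context = sys_test_context_task()
    env_id = test_context[ENV_ID_KEY]
    role_arn = test_context['ROLE_ARN']

Note the split in when things happen: _get_test_name() and set_env_id() run in __init__ at DAG parse time, while the environment/SSM lookups in fetch_variable run inside the generated task at execution time, which is why downstream tasks read the resolved values through XCom rather than from the builder itself.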