airflow/providers/amazon/aws/example_dags/example_step_functions.py
This file was deleted.

@@ -39,7 +39,7 @@ Start an AWS Step Functions state machine execution
To start a new AWS Step Functions state machine execution you can use
:class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`.

-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_step_function_start_execution]
@@ -53,7 +53,7 @@ Get an AWS Step Functions execution output
To fetch the output from an AWS Step Function state machine execution you can
use :class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionGetExecutionOutputOperator`.

-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_step_function_get_execution_output]
@@ -70,7 +70,7 @@ Wait on an AWS Step Functions state machine execution state
To wait on the state of an AWS Step Function state machine execution until it reaches a terminal state you can
use :class:`~airflow.providers.amazon.aws.sensors.step_function.StepFunctionExecutionSensor`.

-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_step_functions.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py
    :language: python
    :dedent: 4
    :start-after: [START howto_sensor_step_function_execution]
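For quick reference, the three documented pieces chain together as follows. This is a condensed sketch of the full example file added below; the only assumption is the placeholder state machine ARN, which stands in for the ARN the test creates at runtime:

from airflow.providers.amazon.aws.operators.step_function import (
    StepFunctionGetExecutionOutputOperator,
    StepFunctionStartExecutionOperator,
)
from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor

# Start an execution, wait until it reaches a terminal state, then fetch its output.
start_execution = StepFunctionStartExecutionOperator(
    task_id='start_execution',
    state_machine_arn='arn:aws:states:us-east-1:123456789012:stateMachine:example',  # placeholder
)
wait_for_execution = StepFunctionExecutionSensor(
    task_id='wait_for_execution', execution_arn=start_execution.output
)
get_execution_output = StepFunctionGetExecutionOutputOperator(
    task_id='get_execution_output', execution_arn=start_execution.output
)
start_execution >> wait_for_execution >> get_execution_output
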
119 changes: 119 additions & 0 deletions tests/system/providers/amazon/aws/example_step_functions.py
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.providers.amazon.aws.operators.step_function import (
    StepFunctionGetExecutionOutputOperator,
    StepFunctionStartExecutionOperator,
)
from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder

DAG_ID = 'example_step_functions'

# Externally fetched variables:
ROLE_ARN_KEY = 'ROLE_ARN'

sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()

STATE_MACHINE_DEFINITION = {
    "StartAt": "Wait",
    "States": {"Wait": {"Type": "Wait", "Seconds": 7, "Next": "Success"}, "Success": {"Type": "Succeed"}},
}
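
# A note on the definition above: this Amazon States Language document describes
# a minimal state machine that waits seven seconds and then succeeds, giving the
# sensor in the DAG below a short, predictable execution to poll.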


@task
def create_state_machine(env_id, role_arn):
    # Create a Step Functions State Machine and return the ARN for use by
    # downstream tasks.
    return (
        StepFunctionHook()
        .get_conn()
        .create_state_machine(
            name=f'{DAG_ID}_{env_id}',
            definition=json.dumps(STATE_MACHINE_DEFINITION),
            roleArn=role_arn,
        )['stateMachineArn']
    )


@task
def delete_state_machine(state_machine_arn):
    StepFunctionHook().get_conn().delete_state_machine(stateMachineArn=state_machine_arn)


with DAG(
    dag_id=DAG_ID,
    schedule_interval='@once',
    start_date=datetime(2021, 1, 1),
    tags=['example'],
    catchup=False,
) as dag:

    # This context contains the ENV_ID and any env variables requested when the
    # task was built above. Access the info as you would any other TaskFlow task.
    test_context = sys_test_context_task()
    env_id = test_context[ENV_ID_KEY]
    role_arn = test_context[ROLE_ARN_KEY]

    state_machine_arn = create_state_machine(env_id, role_arn)

    # [START howto_operator_step_function_start_execution]
    start_execution = StepFunctionStartExecutionOperator(
        task_id='start_execution', state_machine_arn=state_machine_arn
    )
    # [END howto_operator_step_function_start_execution]

    # [START howto_sensor_step_function_execution]
    wait_for_execution = StepFunctionExecutionSensor(
        task_id='wait_for_execution', execution_arn=start_execution.output
    )
    # [END howto_sensor_step_function_execution]

    # [START howto_operator_step_function_get_execution_output]
    get_execution_output = StepFunctionGetExecutionOutputOperator(
        task_id='get_execution_output', execution_arn=start_execution.output
    )
    # [END howto_operator_step_function_get_execution_output]
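
    # Note: start_execution.output is an XComArg; at run time it resolves to the
    # execution ARN that StepFunctionStartExecutionOperator pushes to XCom, and
    # both the sensor and the get-output operator above consume that value.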

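    # chain() wires these tasks into one linear sequence: context and state
    # machine setup first, then the execution steps under test, then teardown.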
    chain(
        # TEST SETUP
        test_context,
        state_machine_arn,
        # TEST BODY
        start_execution,
        wait_for_execution,
        get_execution_output,
        # TEST TEARDOWN
        delete_state_machine(state_machine_arn),
    )

    from tests.system.utils.watcher import watcher

    # This test needs the watcher in order to properly mark success/failure
    # when a "tearDown" task with a trigger rule is part of the DAG.
    list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
55 changes: 50 additions & 5 deletions tests/system/providers/amazon/aws/utils/__init__.py
@@ -26,7 +26,10 @@
from botocore.client import BaseClient
from botocore.exceptions import NoCredentialsError

from airflow.decorators import task

ENV_ID_ENVIRON_KEY: str = 'SYSTEM_TESTS_ENV_ID'
ENV_ID_KEY: str = 'ENV_ID'
DEFAULT_ENV_ID_PREFIX: str = 'env'
DEFAULT_ENV_ID_LEN: int = 8
DEFAULT_ENV_ID: str = f'{DEFAULT_ENV_ID_PREFIX}{str(uuid4())[:DEFAULT_ENV_ID_LEN]}'
@@ -76,19 +79,19 @@ def _validate_env_id(env_id: str) -> str:
    return env_id.lower()


-def _fetch_from_ssm(key: str) -> str:
+def _fetch_from_ssm(key: str, test_name: Optional[str] = None) -> str:
    """
    Test values are stored in the SSM Value as a JSON-encoded dict of key/value pairs.

    :param key: The key to search for within the returned Parameter Value.
    :param test_name: The SSM Parameter name to fetch from; defaults to the current test name.
    :return: The value of the provided key from SSM
    """
-    test_name: str = _get_test_name()
+    _test_name: str = test_name if test_name else _get_test_name()
    ssm_client: BaseClient = boto3.client('ssm')
    value: str = ''

    try:
-        value = json.loads(ssm_client.get_parameter(Name=test_name)['Parameter']['Value'])[key]
+        value = json.loads(ssm_client.get_parameter(Name=_test_name)['Parameter']['Value'])[key]
    # Since a default value after the SSM check is allowed, these exceptions should not stop execution.
    except NoCredentialsError:
        # No boto credentials found.
@@ -102,7 +105,49 @@ def _fetch_from_ssm(key: str) -> str:
    return value


-def fetch_variable(key: str, default_value: Optional[str] = None) -> str:
class SystemTestContextBuilder:
    """This builder class ultimately constructs a TaskFlow task which is run at
    runtime (task execution time). This task generates and stores the test ENV_ID as well
    as any external resources requested (e.g. IAM Roles, VPC, etc.)"""

    def __init__(self):
        self.variables = []
        self.variable_defaults = {}
        self.test_name = _get_test_name()
        self.env_id = set_env_id()

    def add_variable(self, variable_name: str, **kwargs):
        """Register a variable to fetch from environment or cloud parameter store"""
        self.variables.append(variable_name)
        # default_value is accepted via kwargs so that it is completely optional and no
        # default value needs to be provided in the method stub (otherwise we wouldn't
        # be able to tell the difference between our default value and one provided by
        # the caller)
        if 'default_value' in kwargs:
            self.variable_defaults[variable_name] = kwargs['default_value']

        return self  # Builder recipe; returning self allows chaining

    def build(self):
        """Build and return a TaskFlow task which will create an env_id and
        fetch requested variables, storing everything in XCom for downstream
        tasks to use."""

        @task
        def variable_fetcher(**kwargs):
            ti = kwargs['ti']
            for variable in self.variables:
                default_value = self.variable_defaults.get(variable, None)
                value = fetch_variable(variable, default_value, test_name=self.test_name)
                ti.xcom_push(variable, value)

            # Fetch/generate ENV_ID and store it in XCOM
            ti.xcom_push(ENV_ID_KEY, self.env_id)

        return variable_fetcher


def fetch_variable(key: str, default_value: Optional[str] = None, test_name: Optional[str] = None) -> str:
"""
Given a Parameter name: first check for an existing Environment Variable,
then check SSM for a value. If neither are available, fall back on the
@@ -113,7 +158,7 @@ def fetch_variable(key: str, default_value: Optional[str] = None) -> str:
    :return: The value of the parameter.
    """

-    value: Optional[str] = os.getenv(key, _fetch_from_ssm(key)) or default_value
+    value: Optional[str] = os.getenv(key, _fetch_from_ssm(key, test_name=test_name)) or default_value
    if not value:
        raise ValueError(NO_VALUE_MSG.format(key=key))
    return value
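
As a closing note on how these utilities fit together: _fetch_from_ssm expects the SSM Parameter to hold a JSON-encoded dict of key/value pairs stored under the test's name. Below is a hedged sketch of seeding that parameter and consuming it through the builder; the parameter name, account ID, and role ARN are illustrative placeholders, not values from this PR:

import json

import boto3

# Seed the SSM Parameter read by _fetch_from_ssm(): a JSON-encoded dict keyed
# by variable name, stored under the test name (placeholders throughout).
boto3.client('ssm').put_parameter(
    Name='example_step_functions',
    Value=json.dumps({'ROLE_ARN': 'arn:aws:iam::123456789012:role/StepFunctionsTestRole'}),
    Type='String',
    Overwrite=True,
)

# A system test DAG then requests the variable and reads it from XCom at task
# execution time, mirroring the example DAG above:
#   sys_test_context_task = SystemTestContextBuilder().add_variable('ROLE_ARN').build()
#   test_context = sys_test_context_task()
#   role_arn = test_context['ROLE_ARN']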