From 6d216f5b3379896bbd414ef9d1dd65507a3b9e44 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 14 Jun 2022 17:39:42 -0700 Subject: [PATCH] Convert ECS Fargate Sample DAG to System Test --- .../aws/example_dags/example_ecs_fargate.py | 62 -------- airflow/providers/amazon/aws/operators/ecs.py | 7 +- .../operators/ecs.rst | 2 +- .../amazon/aws/operators/test_ecs.py | 7 +- .../amazon/aws/example_ecs_fargate.py | 144 ++++++++++++++++++ .../providers/amazon/aws/utils/__init__.py | 23 ++- 6 files changed, 177 insertions(+), 68 deletions(-) delete mode 100644 airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py create mode 100644 tests/system/providers/amazon/aws/example_ecs_fargate.py diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py deleted file mode 100644 index 1e48367429faa..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os -from datetime import datetime - -from airflow import DAG -from airflow.providers.amazon.aws.operators.ecs import EcsOperator - -with DAG( - dag_id='example_ecs_fargate', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_ecs] - hello_world = EcsOperator( - task_id="hello_world", - cluster=os.environ.get("CLUSTER_NAME", "existing_cluster_name"), - task_definition=os.environ.get("TASK_DEFINITION", "existing_task_definition_name"), - launch_type="FARGATE", - aws_conn_id="aws_ecs", - overrides={ - "containerOverrides": [ - { - "name": "hello-world-container", - "command": ["echo", "hello", "world"], - }, - ], - }, - network_configuration={ - "awsvpcConfiguration": { - "securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")], - "subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")], - }, - }, - tags={ - "Customer": "X", - "Project": "Y", - "Application": "Z", - "Version": "0.0.1", - "Environment": "Development", - }, - awslogs_group="/ecs/hello-world", - awslogs_stream_prefix="prefix_b/hello-world-container", - ) - # [END howto_operator_ecs] diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index d1112edf445ee..7967fe0d139ef 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -219,7 +219,12 @@ class EcsOperator(BaseOperator): """ ui_color = '#f0ede4' - template_fields: Sequence[str] = ('overrides',) + template_fields: Sequence[str] = ( + 'cluster', + 'task_definition', + 'overrides', + 'network_configuration', + ) template_fields_renderers = { "overrides": "json", "network_configuration": "json", diff --git a/docs/apache-airflow-providers-amazon/operators/ecs.rst b/docs/apache-airflow-providers-amazon/operators/ecs.rst index 6a1bd0fce55d9..b4dae280013ed 100644 --- a/docs/apache-airflow-providers-amazon/operators/ecs.rst +++ b/docs/apache-airflow-providers-amazon/operators/ecs.rst @@ -63,7 +63,7 @@ The parameters you need to configure for this Operator will depend upon which `` :end-before: [END howto_operator_ecs] -.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs_fargate.py :language: python :dedent: 4 :start-after: [START howto_operator_ecs] diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 6844b7461d331..701f4d78530d9 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -100,7 +100,12 @@ def test_init(self): self.aws_hook_mock.assert_called_once() def test_template_fields_overrides(self): - assert self.ecs.template_fields == ('overrides',) + assert self.ecs.template_fields == ( + 'cluster', + 'task_definition', + 'overrides', + 'network_configuration', + ) @parameterized.expand( [ diff --git a/tests/system/providers/amazon/aws/example_ecs_fargate.py b/tests/system/providers/amazon/aws/example_ecs_fargate.py new file mode 100644 index 0000000000000..e365c934ad745 --- /dev/null +++ b/tests/system/providers/amazon/aws/example_ecs_fargate.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime + +import boto3 + +from airflow import DAG +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.providers.amazon.aws.operators.ecs import EcsOperator +from airflow.utils.trigger_rule import TriggerRule +from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +DAG_ID = 'example_ecs_fargate' + +# Externally fetched variables: +SUBNETS_KEY = 'SUBNETS' # At least one public subnet is required. +SECURITY_GROUPS_KEY = 'SECURITY_GROUPS' + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(SUBNETS_KEY, split_string=True) + .add_variable(SECURITY_GROUPS_KEY, split_string=True) + .build() +) + + +@task +def create_cluster(cluster_name: str) -> None: + """Creates an ECS cluster.""" + boto3.client('ecs').create_cluster(clusterName=cluster_name) + + +@task +def register_task_definition(task_name: str, container_name: str) -> str: + """Creates a Task Definition.""" + response = boto3.client('ecs').register_task_definition( + family=task_name, + # CPU and Memory are required for Fargate and are set to the lowest currently allowed values. + cpu='256', + memory='512', + containerDefinitions=[ + { + 'name': container_name, + 'image': 'ubuntu', + 'workingDirectory': '/usr/bin', + 'entryPoint': ['sh', '-c'], + 'command': ['ls'], + } + ], + requiresCompatibilities=['FARGATE'], + networkMode='awsvpc', + ) + + return response['taskDefinition']['taskDefinitionArn'] + + +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_task_definition(task_definition_arn: str) -> None: + """Deletes the Task Definition.""" + boto3.client('ecs').deregister_task_definition(taskDefinition=task_definition_arn) + + +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_cluster(cluster_name: str) -> None: + """Deletes the ECS cluster.""" + boto3.client('ecs').delete_cluster(cluster=cluster_name) + + +with DAG( + dag_id=DAG_ID, + schedule_interval='@once', + start_date=datetime(2021, 1, 1), + tags=['example'], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_id = test_context[ENV_ID_KEY] + + cluster_name = f'{env_id}-test-cluster' + container_name = f'{env_id}-test-container' + task_definition_name = f'{env_id}-test-definition' + + create_task_definition = register_task_definition(task_definition_name, container_name) + + # [START howto_operator_ecs] + hello_world = EcsOperator( + task_id='hello_world', + cluster=cluster_name, + task_definition=task_definition_name, + launch_type='FARGATE', + overrides={ + 'containerOverrides': [ + { + 'name': container_name, + 'command': ['echo', 'hello', 'world'], + }, + ], + }, + network_configuration={ + 'awsvpcConfiguration': { + 'subnets': test_context[SUBNETS_KEY], + 'securityGroups': test_context[SECURITY_GROUPS_KEY], + }, + }, + ) + # [END howto_operator_ecs] + + chain( + # TEST SETUP + test_context, + create_cluster(cluster_name), + create_task_definition, + # TEST BODY + hello_world, + # TEST TEARDOWN + delete_task_definition(create_task_definition), + delete_cluster(cluster_name), + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/amazon/aws/utils/__init__.py b/tests/system/providers/amazon/aws/utils/__init__.py index 0c2e16a40e193..f554d6d1cb5d9 100644 --- a/tests/system/providers/amazon/aws/utils/__init__.py +++ b/tests/system/providers/amazon/aws/utils/__init__.py @@ -111,14 +111,29 @@ class SystemTestContextBuilder: as any external resources requested (e.g.g IAM Roles, VPC, etc)""" def __init__(self): - self.variables = [] + self.variables = set() + self.variables_to_split = {} self.variable_defaults = {} self.test_name = _get_test_name() self.env_id = set_env_id() - def add_variable(self, variable_name: str, **kwargs): + def add_variable( + self, + variable_name: str, + split_string: bool = False, + delimiter: Optional[str] = None, + **kwargs, + ): """Register a variable to fetch from environment or cloud parameter store""" - self.variables.append(variable_name) + if variable_name in self.variables: + raise ValueError(f'Variable name {variable_name} already exists in the fetched variables list.') + if delimiter and not split_string: + raise ValueError(f'Variable {variable_name} has a delimiter but split_string is set to False.') + + self.variables.add(variable_name) + if split_string: + self.variables_to_split[variable_name] = delimiter + # default_value is accepted via kwargs so that it is completely optional and no # default value needs to be provided in the method stub (otherwise we wouldn't # be able to tell the difference between our default value and one provided by @@ -139,6 +154,8 @@ def variable_fetcher(**kwargs): for variable in self.variables: default_value = self.variable_defaults.get(variable, None) value = fetch_variable(variable, default_value, test_name=self.test_name) + if variable in self.variables_to_split: + value = value.split(self.variables_to_split[variable]) ti.xcom_push(variable, value) # Fetch/generate ENV_ID and store it in XCOM