diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py index c85ae15b771ec..8bad285ffd8c1 100644 --- a/airflow/contrib/operators/ecs_operator.py +++ b/airflow/contrib/operators/ecs_operator.py @@ -45,6 +45,15 @@ class ECSOperator(BaseOperator): :type region_name: str :param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE') :type launch_type: str + :param group: the name of the task group associated with the task + :type group: str + :param placement_constraints: an array of placement constraint objects to use for + the task + :type placement_constraints: list + :param platform_version: the platform version on which your task is running + :type platform_version: str + :param network_configuration: the network configuration for the task + :type network_configuration: dict """ ui_color = '#f0ede4' @@ -54,7 +63,9 @@ class ECSOperator(BaseOperator): @apply_defaults def __init__(self, task_definition, cluster, overrides, - aws_conn_id=None, region_name=None, launch_type='EC2', **kwargs): + aws_conn_id=None, region_name=None, launch_type='EC2', + group=None, placement_constraints=None, platform_version='LATEST', + network_configuration=None, **kwargs): super(ECSOperator, self).__init__(**kwargs) self.aws_conn_id = aws_conn_id @@ -63,6 +74,10 @@ def __init__(self, task_definition, cluster, overrides, self.cluster = cluster self.overrides = overrides self.launch_type = launch_type + self.group = group + self.placement_constraints = placement_constraints + self.platform_version = platform_version + self.network_configuration = network_configuration self.hook = self.get_hook() @@ -78,13 +93,21 @@ def execute(self, context): region_name=self.region_name ) - response = self.client.run_task( - cluster=self.cluster, - taskDefinition=self.task_definition, - overrides=self.overrides, - startedBy=self.owner, - launchType=self.launch_type - ) + run_opts = { + 'cluster': self.cluster, + 'taskDefinition': self.task_definition, + 'overrides': self.overrides, + 'startedBy': self.owner, + 'launchType': self.launch_type, + 'platformVersion': self.platform_version, + } + if self.group is not None: + run_opts['group'] = self.group + if self.placement_constraints is not None: + run_opts['placementConstraints'] = self.placement_constraints + if self.network_configuration is not None: + run_opts['networkConfiguration'] = self.network_configuration + response = self.client.run_task(**run_opts) failures = response['failures'] if len(failures) > 0: diff --git a/tests/contrib/operators/test_ecs_operator.py b/tests/contrib/operators/test_ecs_operator.py index 43a816da4a1ea..842db1a44a876 100644 --- a/tests/contrib/operators/test_ecs_operator.py +++ b/tests/contrib/operators/test_ecs_operator.py @@ -69,7 +69,20 @@ def setUp(self, aws_hook_mock): cluster='c', overrides={}, aws_conn_id=None, - region_name='eu-west-1') + region_name='eu-west-1', + group='group', + placement_constraints=[ + { + 'expression': 'attribute:ecs.instance-type =~ t2.*', + 'type': 'memberOf' + } + ], + network_configuration={ + 'awsvpcConfiguration': { + 'securityGroups': ['sg-123abc'] + } + } + ) def test_init(self): @@ -100,7 +113,20 @@ def test_execute_without_failures(self, check_mock, wait_mock): launchType='EC2', overrides={}, startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' - taskDefinition='t' + taskDefinition='t', + group='group', + placementConstraints=[ + { + 'expression': 'attribute:ecs.instance-type =~ t2.*', + 'type': 'memberOf' + } + ], + platformVersion='LATEST', + networkConfiguration={ + 'awsvpcConfiguration': { + 'securityGroups': ['sg-123abc'] + } + } ) wait_mock.assert_called_once_with() @@ -123,7 +149,20 @@ def test_execute_with_failures(self): launchType='EC2', overrides={}, startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' - taskDefinition='t' + taskDefinition='t', + group='group', + placementConstraints=[ + { + 'expression': 'attribute:ecs.instance-type =~ t2.*', + 'type': 'memberOf' + } + ], + platformVersion='LATEST', + networkConfiguration={ + 'awsvpcConfiguration': { + 'securityGroups': ['sg-123abc'] + } + } ) def test_wait_end_tasks(self):