Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions airflow/contrib/operators/ecs_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ class ECSOperator(BaseOperator):
:type region_name: str
:param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE')
:type launch_type: str
:param group: the name of the task group associated with the task
:type group: str
:param placement_constraints: an array of placement constraint objects to use for
the task
:type placement_constraints: list
:param platform_version: the platform version on which your task is running
:type platform_version: str
:param network_configuration: the network configuration for the task
:type network_configuration: dict
"""

ui_color = '#f0ede4'
Expand All @@ -54,7 +63,9 @@ class ECSOperator(BaseOperator):

@apply_defaults
def __init__(self, task_definition, cluster, overrides,
aws_conn_id=None, region_name=None, launch_type='EC2', **kwargs):
aws_conn_id=None, region_name=None, launch_type='EC2',
group=None, placement_constraints=None, platform_version='LATEST',
network_configuration=None, **kwargs):
super(ECSOperator, self).__init__(**kwargs)

self.aws_conn_id = aws_conn_id
Expand All @@ -63,6 +74,10 @@ def __init__(self, task_definition, cluster, overrides,
self.cluster = cluster
self.overrides = overrides
self.launch_type = launch_type
self.group = group
self.placement_constraints = placement_constraints
self.platform_version = platform_version
self.network_configuration = network_configuration

self.hook = self.get_hook()

Expand All @@ -78,13 +93,21 @@ def execute(self, context):
region_name=self.region_name
)

response = self.client.run_task(
cluster=self.cluster,
taskDefinition=self.task_definition,
overrides=self.overrides,
startedBy=self.owner,
launchType=self.launch_type
)
run_opts = {
'cluster': self.cluster,
'taskDefinition': self.task_definition,
'overrides': self.overrides,
'startedBy': self.owner,
'launchType': self.launch_type,
'platformVersion': self.platform_version,
}
if self.group is not None:
run_opts['group'] = self.group
if self.placement_constraints is not None:
run_opts['placementConstraints'] = self.placement_constraints
if self.network_configuration is not None:
run_opts['networkConfiguration'] = self.network_configuration
response = self.client.run_task(**run_opts)

failures = response['failures']
if len(failures) > 0:
Expand Down
45 changes: 42 additions & 3 deletions tests/contrib/operators/test_ecs_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,20 @@ def setUp(self, aws_hook_mock):
cluster='c',
overrides={},
aws_conn_id=None,
region_name='eu-west-1')
region_name='eu-west-1',
group='group',
placement_constraints=[
{
'expression': 'attribute:ecs.instance-type =~ t2.*',
'type': 'memberOf'
}
],
network_configuration={
'awsvpcConfiguration': {
'securityGroups': ['sg-123abc']
}
}
)

def test_init(self):

Expand Down Expand Up @@ -100,7 +113,20 @@ def test_execute_without_failures(self, check_mock, wait_mock):
launchType='EC2',
overrides={},
startedBy=mock.ANY, # Can by 'airflow' or 'Airflow'
taskDefinition='t'
taskDefinition='t',
group='group',
placementConstraints=[
{
'expression': 'attribute:ecs.instance-type =~ t2.*',
'type': 'memberOf'
}
],
platformVersion='LATEST',
networkConfiguration={
'awsvpcConfiguration': {
'securityGroups': ['sg-123abc']
}
}
)

wait_mock.assert_called_once_with()
Expand All @@ -123,7 +149,20 @@ def test_execute_with_failures(self):
launchType='EC2',
overrides={},
startedBy=mock.ANY, # Can by 'airflow' or 'Airflow'
taskDefinition='t'
taskDefinition='t',
group='group',
placementConstraints=[
{
'expression': 'attribute:ecs.instance-type =~ t2.*',
'type': 'memberOf'
}
],
platformVersion='LATEST',
networkConfiguration={
'awsvpcConfiguration': {
'securityGroups': ['sg-123abc']
}
}
)

def test_wait_end_tasks(self):
Expand Down