diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py index f0fa97852a0d6..4f57b72d96aef 100644 --- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py @@ -40,6 +40,7 @@ camelize_dict_keys, parse_assign_public_ip, ) +from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.utils.helpers import prune_dict @@ -60,6 +61,22 @@ def build_task_kwargs() -> dict: task_kwargs = _fetch_config_values() task_kwargs.update(_fetch_templated_kwargs()) + has_launch_type: bool = "launch_type" in task_kwargs + has_capacity_provider: bool = "capacity_provider_strategy" in task_kwargs + + if has_capacity_provider and has_launch_type: + raise ValueError( + "capacity_provider_strategy and launch_type are mutually exclusive, you can not provide both." + ) + elif "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type): + # Default API behavior if neither is provided is to fall back on the default capacity + # provider if it exists. Since it is not a required value, check if there is one + # before using it, and if there is not then use the FARGATE launch_type as + # the final fallback. 
+ cluster = EcsHook().conn.describe_clusters(clusters=[task_kwargs["cluster"]])["clusters"][0] + if not cluster.get("defaultCapacityProviderStrategy"): + task_kwargs["launch_type"] = "FARGATE" + # There can only be 1 count of these containers task_kwargs["count"] = 1 # type: ignore # There could be a generic approach to the below, but likely more convoluted then just manually ensuring diff --git a/airflow/providers/amazon/aws/executors/ecs/utils.py b/airflow/providers/amazon/aws/executors/ecs/utils.py index 4966fa3d2b8bc..7913bdf22719c 100644 --- a/airflow/providers/amazon/aws/executors/ecs/utils.py +++ b/airflow/providers/amazon/aws/executors/ecs/utils.py @@ -44,7 +44,6 @@ "conn_id": "aws_default", "max_run_task_attempts": "3", "assign_public_ip": "False", - "launch_type": "FARGATE", "platform_version": "LATEST", "check_health_on_startup": "True", } @@ -81,6 +80,7 @@ class RunTaskKwargsConfigKeys(BaseConfigKeys): """Keys loaded into the config which are valid ECS run_task kwargs.""" ASSIGN_PUBLIC_IP = "assign_public_ip" + CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy" CLUSTER = "cluster" LAUNCH_TYPE = "launch_type" PLATFORM_VERSION = "platform_version" diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index d0f379e73df4f..09af54bcc9a6e 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -828,6 +828,19 @@ config: type: string example: "ecs_executor_cluster" default: ~ + capacity_provider_strategy: + description: | + The capacity provider strategy to use for the task. + + If a Capacity Provider Strategy is specified, the Launch Type parameter must be omitted. If + no Capacity Provider Strategy or Launch Type is specified, the Default Capacity Provider Strategy + for the cluster is used, if present. + + When you use cluster auto scaling, you must specify Capacity Provider Strategy and not Launch Type. 
+ version_added: "8.17" + type: string + example: "[{'capacityProvider': 'cp1', 'weight': 5}, {'capacityProvider': 'cp2', 'weight': 1}]" + default: ~ container_name: description: | Name of the container that will be used to execute Airflow tasks via the ECS executor. @@ -843,6 +856,10 @@ config: Launch type can either be 'FARGATE' OR 'EC2'. For more info see url to Boto3 docs above. + If a Launch Type is specified, the Capacity Provider Strategy parameter must be omitted. If + no Capacity Provider Strategy or Launch Type is specified, the Default Capacity Provider Strategy + for the cluster is used, if present. + If the launch type is EC2, the executor will attempt to place tasks on empty EC2 instances. If there are no EC2 instances available, no task is placed and this function will be called again in the next heart-beat. @@ -852,7 +869,7 @@ config: version_added: "8.10" type: string example: "FARGATE" - default: "FARGATE" + default: ~ platform_version: description: | The platform version the task uses. 
A platform version is only specified diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py index 600f2597d3c34..78c3c1bc2811e 100644 --- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py @@ -47,6 +47,7 @@ _recursive_flatten_dict, parse_assign_public_ip, ) +from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.utils.helpers import convert_camel_to_snake from airflow.utils.state import State, TaskInstanceState @@ -943,9 +944,9 @@ def test_provided_values_override_defaults(self, assign_subnets): assert task_kwargs["platformVersion"] == templated_version - def test_count_can_not_be_modified_by_the_user(self, assign_subnets): + @mock.patch.object(EcsHook, "conn") + def test_count_can_not_be_modified_by_the_user(self, _, assign_subnets): """The ``count`` parameter must always be 1; verify that the user can not override this value.""" - templated_version = "1" templated_cluster = "templated_cluster_name" provided_run_task_kwargs = { @@ -1086,3 +1087,63 @@ def test_start_health_check_config(self, set_env_vars): executor.start() ecs_mock.stop_task.assert_not_called() + + def test_providing_both_capacity_provider_and_launch_type_fails(self, set_env_vars): + os.environ[ + f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CAPACITY_PROVIDER_STRATEGY}".upper() + ] = "[{'capacityProvider': 'cp1', 'weight': 5}, {'capacityProvider': 'cp2', 'weight': 1}]" + expected_error = ( + "capacity_provider_strategy and launch_type are mutually exclusive, you can not provide both." + ) + + with pytest.raises(ValueError, match=expected_error): + AwsEcsExecutor() + + def test_providing_capacity_provider(self, set_env_vars): + # If a capacity provider strategy is supplied without a launch type, use the strategy. 
+ + valid_capacity_provider = ( + "[{'capacityProvider': 'cp1', 'weight': 5}, {'capacityProvider': 'cp2', 'weight': 1}]" + ) + + os.environ[ + f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CAPACITY_PROVIDER_STRATEGY}".upper() + ] = valid_capacity_provider + os.environ.pop(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper()) + + from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config + + task_kwargs = ecs_executor_config.build_task_kwargs() + + assert "launchType" not in task_kwargs + assert task_kwargs["capacityProviderStrategy"] == valid_capacity_provider + + @mock.patch.object(EcsHook, "conn") + def test_providing_no_capacity_provider_no_lunch_type_with_cluster_default(self, mock_conn, set_env_vars): + # If no capacity provider strategy is supplied and no launch type, but the + # cluster has a default capacity provider strategy, use the cluster's default. + mock_conn.describe_clusters.return_value = { + "clusters": [{"defaultCapacityProviderStrategy": ["some_strategy"]}] + } + os.environ.pop(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper()) + + from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config + + task_kwargs = ecs_executor_config.build_task_kwargs() + assert "launchType" not in task_kwargs + assert "capacityProviderStrategy" not in task_kwargs + mock_conn.describe_clusters.assert_called_once() + + @mock.patch.object(EcsHook, "conn") + def test_providing_no_capacity_provider_no_lunch_type_no_cluster_default(self, mock_conn, set_env_vars): + # If no capacity provider strategy is supplied and no launch type, and the cluster + # does not have a default capacity provider strategy, use the FARGATE launch type. 
+ + mock_conn.describe_clusters.return_value = {"clusters": [{"status": "ACTIVE"}]} + + os.environ.pop(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper()) + + from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config + + task_kwargs = ecs_executor_config.build_task_kwargs() + assert task_kwargs["launchType"] == "FARGATE"