From 6ab5de6024626ba86a38298f9129b331a07988a3 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 9 Feb 2022 15:46:41 -0800 Subject: [PATCH 01/19] Update EKS sample DAGs to new standards --- .../aws/example_dags/example_eks_templated.py | 44 ++++++++++--------- .../example_eks_with_fargate_in_one_step.py | 14 +++--- .../example_eks_with_fargate_profile.py | 16 +++++-- .../example_eks_with_nodegroup_in_one_step.py | 16 ++++--- .../example_eks_with_nodegroups.py | 21 ++++++--- 5 files changed, 71 insertions(+), 40 deletions(-) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py index afd19912aa1a9..26e10d7e5f4ce 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py @@ -14,11 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -import os +import json from datetime import datetime from airflow.models.dag import DAG @@ -37,7 +33,6 @@ { "cluster_name": "templated-cluster", "cluster_role_arn": "arn:aws:iam::123456789012:role/role_name", - "nodegroup_subnets": ["subnet-12345ab", "subnet-67890cd"], "resources_vpc_config": { "subnetIds": ["subnet-12345ab", "subnet-67890cd"], "endpointPublicAccess": true, @@ -49,25 +44,24 @@ """ with DAG( - dag_id='to-publish-manuals-templated', - default_args={'cluster_name': "{{ dag_run.conf['cluster_name'] }}"}, + dag_id='example_eks_templated', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example', 'templated'], + catchup=False, # render_template_as_native_obj=True is what converts the Jinja to Python objects, instead of a string. render_template_as_native_obj=True, ) as dag: - SUBNETS = os.environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ') - VPC_CONFIG = { - 'subnetIds': SUBNETS, - 'endpointPublicAccess': True, - 'endpointPrivateAccess': False, - } + + CLUSTER_NAME = "{{ dag_run.conf['cluster_name'] }}" + NODEGROUP_NAME = "{{ dag_run.conf['nodegroup_name'] }}" + VPC_CONFIG = json.loads("{{ dag_run.conf['resources_vpc_config'] }}") + SUBNETS = VPC_CONFIG['subnetIds'] + # Create an Amazon EKS Cluster control plane without attaching a compute service. 
create_cluster = EksCreateClusterOperator( task_id='create_eks_cluster', + cluster_name=CLUSTER_NAME, compute=None, cluster_role_arn="{{ dag_run.conf['cluster_role_arn'] }}", resources_vpc_config=VPC_CONFIG, @@ -75,24 +69,28 @@ await_create_cluster = EksClusterStateSensor( task_id='wait_for_create_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.ACTIVE, ) create_nodegroup = EksCreateNodegroupOperator( task_id='create_eks_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", - nodegroup_subnets="{{ dag_run.conf['nodegroup_subnets'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, + nodegroup_subnets=SUBNETS, nodegroup_role_arn="{{ dag_run.conf['nodegroup_role_arn'] }}", ) await_create_nodegroup = EksNodegroupStateSensor( task_id='wait_for_create_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.ACTIVE, ) start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "ls"], @@ -104,21 +102,25 @@ delete_nodegroup = EksDeleteNodegroupOperator( task_id='delete_eks_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, ) await_delete_nodegroup = EksNodegroupStateSensor( task_id='wait_for_delete_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.NONEXISTENT, ) delete_cluster = EksDeleteClusterOperator( task_id='delete_eks_cluster', + cluster_name=CLUSTER_NAME, ) await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py index 4107058b5ac20..e08e6525e6fc8 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py @@ -43,19 +43,18 @@ with DAG( - dag_id='example-create-cluster-and-fargate-all-in-one', - default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_fargate_in_one_step', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # [START howto_operator_eks_create_cluster_with_fargate_profile] # Create an Amazon EKS cluster control plane and an AWS Fargate compute platform in one step. create_cluster_and_fargate_profile = EksCreateClusterOperator( task_id='create_eks_cluster_and_fargate_profile', + cluster_name=CLUSTER_NAME, cluster_role_arn=ROLE_ARN, resources_vpc_config=VPC_CONFIG, compute='fargate', @@ -68,6 +67,7 @@ await_create_fargate_profile = EksFargateProfileStateSensor( task_id='wait_for_create_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, target_state=FargateProfileStates.ACTIVE, ) @@ -75,6 +75,7 @@ start_pod = EksPodOperator( task_id="run_pod", pod_name="run_pod", + cluster_name=CLUSTER_NAME, image="amazon/aws-cli:latest", cmds=["sh", "-c", "echo Test Airflow; date"], labels={"demo": "hello_world"}, @@ -86,11 +87,14 @@ # An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles. 
# Setting the `force` to `True` will delete any attached resources before deleting the cluster. delete_all = EksDeleteClusterOperator( - task_id='delete_fargate_profile_and_cluster', force_delete_compute=True + task_id='delete_fargate_profile_and_cluster', + cluster_name=CLUSTER_NAME, + force_delete_compute=True, ) await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py index e58e2de729b62..0724a9caac534 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py @@ -46,18 +46,17 @@ with DAG( - dag_id='example_eks_with_fargate_profile_dag', - default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_fargate_profile', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # Create an Amazon EKS Cluster control plane without attaching a compute service. create_cluster = EksCreateClusterOperator( task_id='create_eks_cluster', + cluster_name=CLUSTER_NAME, cluster_role_arn=ROLE_ARN, resources_vpc_config=VPC_CONFIG, compute=None, @@ -65,26 +64,32 @@ await_create_cluster = EksClusterStateSensor( task_id='wait_for_create_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.ACTIVE, ) # [START howto_operator_eks_create_fargate_profile] create_fargate_profile = EksCreateFargateProfileOperator( task_id='create_eks_fargate_profile', + cluster_name=CLUSTER_NAME, pod_execution_role_arn=ROLE_ARN, fargate_profile_name=FARGATE_PROFILE_NAME, selectors=SELECTORS, ) # [END howto_operator_eks_create_fargate_profile] + # [START howto_sensor_eks_fargate] await_create_fargate_profile = EksFargateProfileStateSensor( task_id='wait_for_create_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, target_state=FargateProfileStates.ACTIVE, ) + # [END howto_sensor_eks_fargate] start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "echo Test Airflow; date"], @@ -97,12 +102,14 @@ # [START howto_operator_eks_delete_fargate_profile] delete_fargate_profile = EksDeleteFargateProfileOperator( task_id='delete_eks_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, ) # [END howto_operator_eks_delete_fargate_profile] await_delete_fargate_profile = EksFargateProfileStateSensor( task_id='wait_for_delete_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, target_state=FargateProfileStates.NONEXISTENT, ) @@ -111,6 +118,7 @@ await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py index f19eec622f295..38d1bd1ad4c2f 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py @@ -42,19 +42,18 @@ with DAG( - dag_id='example_eks_using_defaults_dag', - 
default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_nodegroup_in_one_step', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # [START howto_operator_eks_create_cluster_with_nodegroup] # Create an Amazon EKS cluster control plane and an EKS nodegroup compute platform in one step. create_cluster_and_nodegroup = EksCreateClusterOperator( task_id='create_eks_cluster_and_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, cluster_role_arn=ROLE_ARN, nodegroup_role_arn=ROLE_ARN, @@ -68,12 +67,14 @@ await_create_nodegroup = EksNodegroupStateSensor( task_id='wait_for_create_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.ACTIVE, ) start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "echo Test Airflow; date"], @@ -86,11 +87,16 @@ # [START howto_operator_eks_force_delete_cluster] # An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles. # Setting the `force` to `True` will delete any attached resources before deleting the cluster. - delete_all = EksDeleteClusterOperator(task_id='delete_nodegroup_and_cluster', force_delete_compute=True) + delete_all = EksDeleteClusterOperator( + task_id='delete_nodegroup_and_cluster', + cluster_name=CLUSTER_NAME, + force_delete_compute=True, + ) # [END howto_operator_eks_force_delete_cluster] await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py index 3ec6a3ac459a0..305f6aef6adbd 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py @@ -45,48 +45,55 @@ with DAG( - dag_id='example_eks_with_nodegroups_dag', - default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_nodegroups', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # [START howto_operator_eks_create_cluster] # Create an Amazon EKS Cluster control plane without attaching compute service. 
create_cluster = EksCreateClusterOperator( task_id='create_eks_cluster', + cluster_name=CLUSTER_NAME, cluster_role_arn=ROLE_ARN, resources_vpc_config=VPC_CONFIG, compute=None, ) # [END howto_operator_eks_create_cluster] + # [START howto_sensor_eks_cluster] await_create_cluster = EksClusterStateSensor( task_id='wait_for_create_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.ACTIVE, ) + # [END howto_sensor_eks_cluster] # [START howto_operator_eks_create_nodegroup] create_nodegroup = EksCreateNodegroupOperator( task_id='create_eks_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, nodegroup_subnets=SUBNETS, nodegroup_role_arn=ROLE_ARN, ) # [END howto_operator_eks_create_nodegroup] + # [START howto_sensor_eks_nodegroup] await_create_nodegroup = EksNodegroupStateSensor( task_id='wait_for_create_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.ACTIVE, ) + # [END howto_sensor_eks_nodegroup] # [START howto_operator_eks_pod_operator] start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "ls"], @@ -99,12 +106,15 @@ # [START howto_operator_eks_delete_nodegroup] delete_nodegroup = EksDeleteNodegroupOperator( - task_id='delete_eks_nodegroup', nodegroup_name=NODEGROUP_NAME + task_id='delete_eks_nodegroup', + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, ) # [END howto_operator_eks_delete_nodegroup] await_delete_nodegroup = EksNodegroupStateSensor( task_id='wait_for_delete_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.NONEXISTENT, ) @@ -115,6 +125,7 @@ await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) From f6cb933ef88479e0f5159cc892f9b91f69394b68 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 9 Feb 2022 16:52:13 -0800 Subject: [PATCH 02/19] Update EKS docs to new standards --- airflow/providers/amazon/aws/sensors/eks.py | 12 +++++ .../operators/eks.rst | 54 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/airflow/providers/amazon/aws/sensors/eks.py b/airflow/providers/amazon/aws/sensors/eks.py index 7f639b684103f..92ed55da4d31e 100644 --- a/airflow/providers/amazon/aws/sensors/eks.py +++ b/airflow/providers/amazon/aws/sensors/eks.py @@ -60,6 +60,10 @@ class EksClusterStateSensor(BaseSensorOperator): """ Check the state of an Amazon EKS Cluster until it reaches the target state or another terminal state. + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EksClusterStateSensor` + :param cluster_name: The name of the Cluster to watch. (templated) :param target_state: Target state of the Cluster. (templated) :param region: Which AWS region the connection should use. (templated) @@ -116,6 +120,10 @@ class EksFargateProfileStateSensor(BaseSensorOperator): """ Check the state of an AWS Fargate profile until it reaches the target state or another terminal state. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/sensor:EksFargateProfileStateSensor` + :param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. (templated) :param fargate_profile_name: The name of the Fargate profile to watch. (templated) :param target_state: Target state of the Fargate profile. 
(templated) @@ -183,6 +191,10 @@ class EksNodegroupStateSensor(BaseSensorOperator): """ Check the state of an EKS managed node group until it reaches the target state or another terminal state. + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EksNodegroupStateSensor` + :param cluster_name: The name of the Cluster which the Nodegroup is attached to. (templated) :param nodegroup_name: The name of the Nodegroup to watch. (templated) :param target_state: Target state of the Nodegroup. (templated) diff --git a/docs/apache-airflow-providers-amazon/operators/eks.rst b/docs/apache-airflow-providers-amazon/operators/eks.rst index e2b856de91dd7..d6cd5ad791973 100644 --- a/docs/apache-airflow-providers-amazon/operators/eks.rst +++ b/docs/apache-airflow-providers-amazon/operators/eks.rst @@ -34,6 +34,21 @@ Prerequisite Tasks Manage Amazon EKS Clusters ^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. _howto/sensor:EksClusterStateSensor: + +Amazon EKS Cluster State Sensor +""""""""""""""""""""""""""""""" + +To check the state of an Amazon EKS Cluster until it reaches the target state or another terminal +state you can use :class:`~airflow.providers.amazon.aws.sensors.eks.EksClusterStateSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_eks_cluster] + :end-before: [END howto_sensor_eks_cluster] + + .. _howto/operator:EksCreateClusterOperator: Create an Amazon EKS Cluster @@ -48,6 +63,7 @@ Note: An AWS IAM role with the following permissions is required: .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_create_cluster] :end-before: [END howto_operator_eks_create_cluster] @@ -61,6 +77,7 @@ To delete an existing Amazon EKS Cluster you can use .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_delete_cluster] :end-before: [END howto_operator_eks_delete_cluster] @@ -70,6 +87,7 @@ Note: If the cluster has any attached resources, such as an Amazon EKS Nodegroup .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_force_delete_cluster] :end-before: [END howto_operator_eks_force_delete_cluster] @@ -77,6 +95,20 @@ Note: If the cluster has any attached resources, such as an Amazon EKS Nodegroup Manage Amazon EKS Managed Nodegroups ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. _howto/sensor:EksNodegroupStateSensor: + +Amazon EKS Managed Nodegroup State Sensor +""""""""""""""""""""""""""""""""""""""""" + +To check the state of an Amazon EKS managed node group until it reaches the target state or another terminal +state you can use :class:`~airflow.providers.amazon.aws.sensors.eks.EksNodegroupStateSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_eks_nodegroup] + :end-before: [END howto_sensor_eks_nodegroup] + .. _howto/operator:EksCreateNodegroupOperator: Create an Amazon EKS Managed NodeGroup @@ -92,6 +124,7 @@ Note: An AWS IAM role with the following permissions is required: .. 
exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_create_nodegroup] :end-before: [END howto_operator_eks_create_nodegroup] @@ -105,6 +138,7 @@ To delete an existing Amazon EKS Managed Nodegroup you can use .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_delete_nodegroup] :end-before: [END howto_operator_eks_delete_nodegroup] @@ -124,6 +158,7 @@ Note: An AWS IAM role with the following permissions is required: .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_create_cluster_with_nodegroup] :end-before: [END howto_operator_eks_create_cluster_with_nodegroup] @@ -142,12 +177,28 @@ Note: An AWS IAM role with the following permissions is required: .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_create_cluster_with_fargate_profile] :end-before: [END howto_operator_eks_create_cluster_with_fargate_profile] Manage AWS Fargate Profiles ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. _howto/sensor:EksFargateProfileStateSensor: + +AWS Fargate Profile State Sensor +"""""""""""""""""""""""""""""""" + +To check the state of an AWS Fargate profile until it reaches the target state or another terminal +state you can use :class:`~airflow.providers.amazon.aws.sensors.eks.EksFargateProfileSensor`. + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_eks_fargate] + :end-before: [END howto_sensor_eks_fargate] + + .. _howto/operator:EksCreateFargateProfileOperator: Create an AWS Fargate Profile @@ -163,6 +214,7 @@ Note: An AWS IAM role with the following permissions is required: .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_create_fargate_profile] :end-before: [END howto_operator_eks_create_fargate_profile] @@ -176,6 +228,7 @@ To delete an existing AWS Fargate Profile you can use .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_delete_fargate_profile] :end-before: [END howto_operator_eks_delete_fargate_profile] @@ -191,6 +244,7 @@ Note: An Amazon EKS Cluster with underlying compute infrastructure is required. .. 
exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py :language: python + :dedent: 4 :start-after: [START howto_operator_eks_pod_operator] :end-before: [END howto_operator_eks_pod_operator] From 009807064b8d09b8ea2126a9a4afc18fa88a3f56 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 9 Feb 2022 13:21:39 -0800 Subject: [PATCH 03/19] Allow blank lines after docstring (#21477) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 259303c3073e8..a73abec8487f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -232,7 +232,7 @@ repos: name: Run pydocstyle args: - --convention=pep257 - - --add-ignore=D100,D102,D103,D104,D105,D107,D205,D400,D401 + - --add-ignore=D100,D102,D103,D104,D105,D107,D202,D205,D400,D401 exclude: | (?x) ^tests/.*\.py$| From c922e2bc122a1d57224dc75ebcde3c5ae3388518 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 9 Feb 2022 15:37:17 -0800 Subject: [PATCH 04/19] Fix typing of operator attrs for mypy (#21480) --- airflow/models/baseoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 34c84128391dc..8f961539d9813 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -671,7 +671,7 @@ def __init__( ) self.trigger_rule = trigger_rule - self.depends_on_past = depends_on_past + self.depends_on_past: bool = depends_on_past self.wait_for_downstream = wait_for_downstream if wait_for_downstream: self.depends_on_past = True @@ -714,7 +714,7 @@ def __init__( stacklevel=2, ) max_active_tis_per_dag = task_concurrency - self.max_active_tis_per_dag = max_active_tis_per_dag + self.max_active_tis_per_dag: Optional[int] = max_active_tis_per_dag self.do_xcom_push = do_xcom_push self.doc_md = doc_md From b617496f80467977438270a801f9571c14c5f19e Mon Sep 17 00:00:00 2001 From: "D. Ferruzzi" Date: Wed, 9 Feb 2022 20:36:34 -0800 Subject: [PATCH 05/19] Added SNS example DAG and rst (#21475) --- .../amazon/aws/example_dags/example_sns.py | 39 ++++++++++++ airflow/providers/amazon/aws/operators/sns.py | 4 ++ airflow/providers/amazon/provider.yaml | 2 + .../operators/sns.rst | 59 +++++++++++++++++++ 4 files changed, 104 insertions(+) create mode 100644 airflow/providers/amazon/aws/example_dags/example_sns.py create mode 100644 docs/apache-airflow-providers-amazon/operators/sns.rst diff --git a/airflow/providers/amazon/aws/example_dags/example_sns.py b/airflow/providers/amazon/aws/example_dags/example_sns.py new file mode 100644 index 0000000000000..782156b14c3d3 --- /dev/null +++ b/airflow/providers/amazon/aws/example_dags/example_sns.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime +from os import environ + +from airflow import DAG +from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator + +SNS_TOPIC_ARN = environ.get('SNS_TOPIC_ARN', 'arn:aws:sns:us-west-2:123456789012:dummy-topic-name') + +with DAG( + dag_id='example_sns', + schedule_interval=None, + start_date=datetime(2021, 1, 1), + tags=['example'], + catchup=False, +) as dag: + + # [START howto_operator_sns_publish_operator] + publish = SnsPublishOperator( + task_id='publish_message', + target_arn=SNS_TOPIC_ARN, + message='This is a sample message sent to SNS via an Apache Airflow DAG task.', + ) + # [END howto_operator_sns_publish_operator] diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py index 48a436b020745..e916798d03386 100644 --- a/airflow/providers/amazon/aws/operators/sns.py +++ b/airflow/providers/amazon/aws/operators/sns.py @@ -30,6 +30,10 @@ class SnsPublishOperator(BaseOperator): """ Publish a message to Amazon SNS. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SnsPublishOperator` + :param aws_conn_id: aws connection to use :param target_arn: either a TopicArn or an EndpointArn :param message: the default message you want to send (templated) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index f6887e0477e6c..962e48b10541c 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -127,6 +127,8 @@ integrations: - integration-name: Amazon Simple Notification Service (SNS) external-doc-url: https://aws.amazon.com/sns/ logo: /integration-logos/aws/Amazon-Simple-Notification-Service-SNS_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/sns.rst tags: [aws] - integration-name: Amazon Simple Queue Service (SQS) external-doc-url: https://aws.amazon.com/sqs/ diff --git a/docs/apache-airflow-providers-amazon/operators/sns.rst b/docs/apache-airflow-providers-amazon/operators/sns.rst new file mode 100644 index 0000000000000..1853bf27ae9dc --- /dev/null +++ b/docs/apache-airflow-providers-amazon/operators/sns.rst @@ -0,0 +1,59 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
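
For orientation, the example task above hands the target ARN and message to SnsPublishOperator, which goes through the Amazon provider's SNS hook and ultimately the SNS Publish API. A rough stand-alone equivalent in plain boto3, reusing the placeholder ARN and message from example_sns.py (the region is inferred from that ARN; this is an illustration, not the operator's actual implementation):

    import boto3

    # Same placeholder target and message as the example DAG; substitute a real topic ARN to try it.
    client = boto3.client("sns", region_name="us-west-2")
    client.publish(
        TargetArn="arn:aws:sns:us-west-2:123456789012:dummy-topic-name",
        Message="This is a sample message sent to SNS via an Apache Airflow DAG task.",
    )
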
+ + +Amazon Simple Notification Service (SNS) Operators +================================================== + +`Amazon Simple Notification Service (Amazon SNS) `__ is a managed +service that provides message delivery from publishers to subscribers (also known as producers +and consumers). Publishers communicate asynchronously with subscribers by sending messages to +a topic, which is a logical access point and communication channel. Clients can subscribe to the +SNS topic and receive published messages using a supported endpoint type, such as Amazon Kinesis +Data Firehose, Amazon SQS, AWS Lambda, HTTP, email, mobile push notifications, and mobile text +messages (SMS). + +Airflow provides an operator to publish messages to an SNS Topic. + +Prerequisite Tasks +^^^^^^^^^^^^^^^^^^ + +.. include::/operators/_partials/prerequisite_tasks.rst + + +.. _howto/operator:SnsPublishOperator: + +Publish A Message To An Existing SNS Topic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To publish a message to an Amazon SNS Topic you can use +:class:`~airflow.providers.amazon.aws.operators.sns.SnsPublishOperator`. + + +.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sns.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sns_publish_operator] + :end-before: [END howto_operator_sns_publish_operator] + + +Reference +^^^^^^^^^ + +For further information, look at: + +* `Boto3 Library Documentation for SNS `__ From adeba230facaeedaed1e7eb35e925c384efe73a7 Mon Sep 17 00:00:00 2001 From: James Timmins Date: Wed, 9 Feb 2022 21:37:53 -0800 Subject: [PATCH 06/19] Simplify fab has access lookup (#21482) Co-authored-by: Ash Berlin-Taylor --- airflow/www/fab_security/sqla/models.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/airflow/www/fab_security/sqla/models.py b/airflow/www/fab_security/sqla/models.py index 69853722d59b7..93a95a45c5e21 100644 --- a/airflow/www/fab_security/sqla/models.py +++ b/airflow/www/fab_security/sqla/models.py @@ -37,7 +37,6 @@ ) from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import backref, relationship -from sqlalchemy.orm.relationships import foreign """ Compatibility note: The models in this file are duplicated from Flask AppBuilder. 
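
The hunk below drops the explicit primaryjoin and backref arguments from the Permission relationships, letting SQLAlchemy infer the join condition from the ForeignKey columns alone. A minimal, self-contained sketch of that simplified pattern (illustrative only; the table names are guessed from the column definitions shown here, not taken verbatim from the Flask AppBuilder schema):

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Action(Base):
        __tablename__ = "ab_permission"
        id = Column(Integer, primary_key=True)

    class Permission(Base):
        __tablename__ = "ab_permission_view"
        id = Column(Integer, primary_key=True)
        # The ForeignKey is enough for SQLAlchemy to work out the join,
        # so no primaryjoin (and no foreign() import) is needed.
        action_id = Column("permission_id", Integer, ForeignKey("ab_permission.id"))
        action = relationship("Action", uselist=False, lazy="joined")
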
@@ -140,17 +139,13 @@ class Permission(Model): action_id = Column("permission_id", Integer, ForeignKey("ab_permission.id")) action = relationship( "Action", - primaryjoin=action_id == foreign(Action.id), uselist=False, - backref="permission", lazy="joined", ) resource_id = Column("view_menu_id", Integer, ForeignKey("ab_view_menu.id")) resource = relationship( "Resource", - primaryjoin=resource_id == foreign(Resource.id), uselist=False, - backref="permission", lazy="joined", ) From c13aa11a3d9578d145ecfc2ea20da4f163b03bbd Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 10 Feb 2022 15:07:50 +0800 Subject: [PATCH 07/19] Rewrite decorated task mapping (#21328) --- airflow/decorators/base.py | 92 +++++++++++++++---- airflow/models/baseoperator.py | 55 ++--------- airflow/models/taskinstance.py | 2 +- airflow/serialization/serialized_objects.py | 27 +++++- tests/dags/test_mapped_taskflow.py | 31 +++++++ tests/decorators/test_python.py | 69 +++++++++++--- tests/jobs/test_backfill_job.py | 16 +++- tests/serialization/test_dag_serialization.py | 53 +++++++++++ 8 files changed, 259 insertions(+), 86 deletions(-) create mode 100644 tests/dags/test_mapped_taskflow.py diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9cf423fb69e11..53a12c62d2318 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -280,30 +280,88 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names names = ", ".join(repr(n) for n in unknown_args) raise TypeError(f'{funcname} got unexpected keyword arguments {names}') - def map( - self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs - ) -> XComArg: + def map(self, *args, **kwargs) -> XComArg: self._validate_arg_names("map", kwargs) - dag = dag or DagContext.get_current_dag() - task_group = task_group or TaskGroupContext.get_current_task_group(dag) - task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group) - operator = MappedOperator.from_decorator( - decorator=self, + partial_kwargs = self.kwargs.copy() + dag = partial_kwargs.pop("dag", DagContext.get_current_dag()) + task_group = partial_kwargs.pop("task_group", TaskGroupContext.get_current_task_group(dag)) + task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) + + # Unfortunately attrs's type hinting support does not work well with + # subclassing; it complains that arguments forwarded to the superclass + # are "unexpected" (they are fine at runtime). 
+ operator = cast(Any, DecoratedMappedOperator)( + operator_class=self.operator_class, + partial_kwargs=partial_kwargs, + mapped_kwargs={}, + task_id=task_id, dag=dag, task_group=task_group, - task_id=task_id, - mapped_kwargs=kwargs, + deps=MappedOperator._deps(self.operator_class.deps), + multiple_outputs=self.multiple_outputs, + python_callable=self.function, ) + + operator.mapped_kwargs["op_args"] = list(args) + operator.mapped_kwargs["op_kwargs"] = kwargs + + for arg in itertools.chain(args, kwargs.values()): + XComArg.apply_upstream_relationship(operator, arg) return XComArg(operator=operator) - def partial( - self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs - ) -> "_TaskDecorator[Function, OperatorSubclass]": - self._validate_arg_names("partial", kwargs, {'task_id'}) - partial_kwargs = self.kwargs.copy() - partial_kwargs.update(kwargs) - return attr.evolve(self, kwargs=partial_kwargs) + def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]": + self._validate_arg_names("partial", kwargs) + + op_args = self.kwargs.get("op_args", []) + op_args.extend(args) + + op_kwargs = self.kwargs.get("op_kwargs", {}) + op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial") + + return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs}) + + +def _merge_kwargs( + kwargs1: Dict[str, XComArg], + kwargs2: Dict[str, XComArg], + *, + fail_reason: str, +) -> Dict[str, XComArg]: + duplicated_keys = set(kwargs1).intersection(kwargs2) + if len(duplicated_keys) == 1: + raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}") + elif duplicated_keys: + duplicated_keys_display = ", ".join(sorted(duplicated_keys)) + raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") + return {**kwargs1, **kwargs2} + + +@attr.define(kw_only=True) +class DecoratedMappedOperator(MappedOperator): + """MappedOperator implementation for @task-decorated task function.""" + + multiple_outputs: bool + python_callable: Callable + + def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: + assert not isinstance(self.operator_class, str) + op_args = self.partial_kwargs.pop("op_args", []) + self.mapped_kwargs.pop("op_args", []) + op_kwargs = _merge_kwargs( + self.partial_kwargs.pop("op_kwargs", {}), + self.mapped_kwargs.pop("op_kwargs", {}), + fail_reason="mapping already partial", + ) + return self.operator_class( + dag=dag, + task_id=self.task_id, + op_args=op_args, + op_kwargs=op_kwargs, + multiple_outputs=self.multiple_outputs, + python_callable=self.python_callable, + **self.partial_kwargs, + **self.mapped_kwargs, + ) class Task(Generic[Function]): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 8f961539d9813..35a0fbb53fe8a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -82,7 +82,6 @@ from airflow.utils.weight_rule import WeightRule if TYPE_CHECKING: - from airflow.decorators.base import _TaskDecorator from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup @@ -243,7 +242,7 @@ def __new__(cls, name, bases, namespace, **kwargs): return new_cls # The class level partial function. 
This is what handles the actual mapping - def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs): + def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs) -> "MappedOperator": operator_class = cast("Type[BaseOperator]", cls) # Validate that the args we passed are known -- at call/DAG parse time, not run time! _validate_kwarg_names_for_mapping(operator_class, "partial", kwargs) @@ -1632,7 +1631,7 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> dag._remove_task(operator.task_id) operator_init_kwargs: dict = operator._BaseOperator__init_kwargs # type: ignore - return MappedOperator( + return cls( operator_class=type(operator), task_id=operator.task_id, task_group=task_group, @@ -1648,37 +1647,6 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> deps=cls._deps(operator.deps), ) - @classmethod - def from_decorator( - cls, - *, - decorator: "_TaskDecorator", - dag: Optional["DAG"], - task_group: Optional["TaskGroup"], - task_id: str, - mapped_kwargs: Dict[str, Any], - ) -> "MappedOperator": - """Create a mapped operator from a task decorator. - - Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``. - The task decorator calling this should be responsible for validation. - """ - from airflow.models.xcom_arg import XComArg - - operator = MappedOperator( - operator_class=decorator.operator_class, - partial_kwargs=decorator.kwargs, - mapped_kwargs={}, - task_id=task_id, - dag=dag, - task_group=task_group, - deps=cls._deps(decorator.operator_class.deps), - ) - operator.mapped_kwargs.update(mapped_kwargs) - for arg in mapped_kwargs.values(): - XComArg.apply_upstream_relationship(operator, arg) - return operator - @classmethod def _deps(cls, deps: Iterable[BaseTIDep]): if deps is BaseOperator.deps: @@ -1749,7 +1717,7 @@ def inherits_from_dummy_operator(self): @classmethod def get_serialized_fields(cls): if cls.__serialized_fields is None: - fields_dict = attr.fields_dict(cls) + fields_dict = attr.fields_dict(MappedOperator) cls.__serialized_fields = frozenset( fields_dict.keys() - { @@ -1902,22 +1870,17 @@ def expand_mapped_task( return ret - def unmap(self) -> BaseOperator: - """Get the "normal" Operator after applying the current mapping""" + def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: assert not isinstance(self.operator_class, str) + return self.operator_class(dag=dag, task_id=self.task_id, **self.partial_kwargs, **self.mapped_kwargs) + def unmap(self) -> BaseOperator: + """Get the "normal" Operator after applying the current mapping""" dag = self.get_dag() if not dag: - raise RuntimeError("Cannot unmapp a task unless it has a dag") - - args = { - **self.partial_kwargs, - **self.mapped_kwargs, - } + raise RuntimeError("Cannot unmap a task unless it has a DAG") dag._remove_task(self.task_id) - task = self.operator_class(task_id=self.task_id, dag=self.dag, **args) - - return task + return self.create_unmapped_operator(dag) # TODO: Deprecate for Airflow 3.0 diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4996b9a7db073..f10032dfc2964 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1713,7 +1713,7 @@ def handle_failure( test_mode: Optional[bool] = None, force_fail: bool = False, error_file: Optional[str] = None, - session=NEW_SESSION, + session: Session = NEW_SESSION, ) -> None: """Handle Failure for the TaskInstance""" if test_mode is None: diff --git 
a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d6abda7c74899..017f2276964ca 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -16,6 +16,7 @@ # under the License. """Serialized DAG and BaseOperator""" +import contextlib import datetime import enum import logging @@ -168,7 +169,7 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable: return timetable_class.deserialize(var[Encoding.VAR]) -class _XcomRef(NamedTuple): +class _XComRef(NamedTuple): """ Used to store info needed to create XComArg when deserializing MappedOperator. @@ -497,8 +498,8 @@ def _serialize_xcomarg(cls, arg: XComArg) -> dict: return {"key": arg.key, "task_id": arg.operator.task_id} @classmethod - def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef: - return _XcomRef(key=encoded['key'], task_id=encoded['task_id']) + def _deserialize_xcomref(cls, encoded: dict) -> _XComRef: + return _XComRef(key=encoded['key'], task_id=encoded['task_id']) class DependencyDetector: @@ -566,9 +567,19 @@ def task_type(self, task_type: str): @classmethod def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: - stock_deps = op.deps is MappedOperator.DEFAULT_DEPS serialize_op = cls._serialize_node(op, include_deps=not stock_deps) + + # Simplify op_kwargs format. It must be a dict, so we flatten it. + with contextlib.suppress(KeyError): + op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"] + assert op_kwargs[Encoding.TYPE] == DAT.DICT + serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] + with contextlib.suppress(KeyError): + op_kwargs = serialize_op["partial_kwargs"]["op_kwargs"] + assert op_kwargs[Encoding.TYPE] == DAT.DICT + serialize_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] + # It must be a class at this point for it to work, not a string assert isinstance(op.operator_class, type) serialize_op['_task_type'] = op.operator_class.__name__ @@ -715,7 +726,13 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, elif k == "params": v = cls._deserialize_params_dict(v) elif k in ("mapped_kwargs", "partial_kwargs"): + if "op_kwargs" not in v: + op_kwargs: Optional[dict] = None + else: + op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()} v = {arg: cls._deserialize(value) for arg, value in v.items()} + if op_kwargs is not None: + v["op_kwargs"] = op_kwargs elif k in cls._decorated_fields or k not in op.get_serialized_fields(): v = cls._deserialize(v) # else use v as it is @@ -1002,7 +1019,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': if isinstance(task, MappedOperator): for d in (task.mapped_kwargs, task.partial_kwargs): for k, v in d.items(): - if not isinstance(v, _XcomRef): + if not isinstance(v, _XComRef): continue d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key) diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py new file mode 100644 index 0000000000000..f21a9a5e8a42d --- /dev/null +++ b/tests/dags/test_mapped_taskflow.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow import DAG +from airflow.utils.dates import days_ago + +with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag: + + @dag.task + def make_list(): + return [1, 2, {'a': 'b'}] + + @dag.task + def consumer(value): + print(repr(value)) + + consumer.map(value=make_list()) diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 0c93b49e1fe00..ee94fde610d7a 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -17,7 +17,7 @@ # under the License. import sys from collections import namedtuple -from datetime import date, timedelta +from datetime import date, datetime, timedelta from typing import Dict # noqa: F401 # This is used by annotation tests. from typing import Tuple @@ -490,7 +490,7 @@ def double(number: int): assert isinstance(doubled_0, XComArg) assert isinstance(doubled_0.operator, MappedOperator) assert doubled_0.operator.task_id == "double" - assert doubled_0.operator.mapped_kwargs == {"number": literal} + assert doubled_0.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}} assert doubled_1.operator.task_id == "double__1" @@ -514,25 +514,68 @@ def test_partial_mapped_decorator() -> None: def product(number: int, multiple: int): return number * multiple + literal = [1, 2, 3] + with DAG('test_dag', start_date=DEFAULT_DATE) as dag: - literal = [1, 2, 3] - quadrupled = product.partial(task_id='times_4', multiple=3).map(number=literal) + quadrupled = product.partial(multiple=3).map(number=literal) doubled = product.partial(multiple=2).map(number=literal) trippled = product.partial(multiple=3).map(number=literal) - product.partial(multiple=2) + product.partial(multiple=2) # No operator is actually created. + + assert dag.task_dict == { + "product": quadrupled.operator, + "product__1": doubled.operator, + "product__2": trippled.operator, + } assert isinstance(doubled, XComArg) assert isinstance(doubled.operator, MappedOperator) - assert doubled.operator.task_id == "product" - assert doubled.operator.mapped_kwargs == {"number": literal} - assert doubled.operator.partial_kwargs == {"task_id": "product", "multiple": 2} + assert doubled.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}} + assert doubled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 2}} - assert trippled.operator.task_id == "product__1" - assert trippled.operator.partial_kwargs == {"task_id": "product", "multiple": 3} - - assert quadrupled.operator.task_id == "times_4" + assert isinstance(trippled.operator, MappedOperator) # For type-checking on partial_kwargs. + assert trippled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 3}} assert doubled.operator is not trippled.operator - assert [quadrupled.operator, doubled.operator, trippled.operator] == dag.tasks + +def test_mapped_decorator_unmap_merge_op_kwargs(): + with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + + @task_decorator + def task1(): + ... + + @task_decorator + def task2(arg1, arg2): + ... 
+ + task2.partial(arg1=1).map(arg2=task1()) + + unmapped = dag.get_task("task2").unmap() + assert set(unmapped.op_kwargs) == {"arg1", "arg2"} + + +def test_mapped_decorator_unmap_converts_partial_kwargs(): + with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + + @task_decorator + def task1(arg): + ... + + @task_decorator(retry_delay=30) + def task2(arg1, arg2): + ... + + task2.partial(arg1=1).map(arg2=task1.map(arg=[1, 2])) + + # Arguments to the task decorator are stored in partial_kwargs, and + # converted into their intended form after the task is unmapped. + mapped_task2 = dag.get_task("task2") + assert mapped_task2.partial_kwargs["retry_delay"] == 30 + assert mapped_task2.unmap().retry_delay == timedelta(seconds=30) + + mapped_task1 = dag.get_task("task1") + assert "retry_delay" not in mapped_task1.partial_kwargs + mapped_task1.unmap().retry_delay == timedelta(seconds=300) # Operator default. diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 0878f63ddffc6..40593d526a328 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -47,7 +47,13 @@ from airflow.utils.timeout import timeout from airflow.utils.types import DagRunType from tests.models import TEST_DAGS_FOLDER -from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots +from tests.test_utils.db import ( + clear_db_dags, + clear_db_pools, + clear_db_runs, + clear_db_xcom, + set_default_pool_slots, +) from tests.test_utils.mock_executor import MockExecutor from tests.test_utils.timetables import cron_timetable @@ -66,6 +72,7 @@ class TestBackfillJob: def clean_db(): clear_db_dags() clear_db_runs() + clear_db_xcom() clear_db_pools() @pytest.fixture(autouse=True) @@ -1512,13 +1519,14 @@ def test_backfill_has_job_id(self): job.run() assert executor.job_id is not None - def test_mapped_dag(self, dag_maker): + @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"]) + def test_mapped_dag(self, dag_id): """End-to-end test of a simple mapped dag""" # Use SequentialExecutor for more predictable test behaviour from airflow.executors.sequential_executor import SequentialExecutor - self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py')) - dag = self.dagbag.get_dag('test_mapped_classic') + self.dagbag.process_file(str(TEST_DAGS_FOLDER / f'{dag_id}.py')) + dag = self.dagbag.get_dag(dag_id) # This needs a real executor to run, so that the `make_list` task can write out the TaskMap diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 447b1732a78b0..1e8d510fd7205 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1654,6 +1654,59 @@ def test_mapped_operator_xcomarg_serde(): assert xcom_arg.operator is serialized_dag.task_dict['op1'] +def test_mapped_decorator_serde(): + from airflow.decorators import task + from airflow.models.xcom_arg import XComArg + from airflow.serialization.serialized_objects import _XComRef + + with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + op1 = BaseOperator(task_id="op1") + xcomarg = XComArg(op1, "my_key") + + @task(retry_delay=30) + def x(arg1, arg2, arg3, arg4): + print(arg1, arg2, arg3, arg4) + + x.partial("foo", arg3=[1, 2, {"a": "b"}]).map({"a": 1, "b": 2}, arg4=xcomarg) + + original = dag.get_task("x") + + serialized = SerializedBaseOperator._serialize(original) + assert serialized == { + '_is_dummy': False, + 
'_is_mapped': True, + '_task_module': 'airflow.decorators.python', + '_task_type': '_PythonDecoratedOperator', + 'downstream_task_ids': [], + 'partial_kwargs': { + 'op_args': ["foo"], + 'op_kwargs': {'arg3': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]}, + 'retry_delay': 30, + }, + 'mapped_kwargs': { + 'op_args': [{"__type": "dict", "__var": {'a': 1, 'b': 2}}], + 'op_kwargs': {'arg4': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'my_key'}}}, + }, + 'task_id': 'x', + 'template_ext': [], + 'template_fields': ['op_args', 'op_kwargs'], + } + + deserialized = SerializedBaseOperator.deserialize_operator(serialized) + assert isinstance(deserialized, MappedOperator) + assert deserialized.deps is MappedOperator.DEFAULT_DEPS + + assert deserialized.mapped_kwargs == { + "op_args": [{"a": 1, "b": 2}], + "op_kwargs": {"arg4": _XComRef("op1", "my_key")}, + } + assert deserialized.partial_kwargs == { + "retry_delay": 30, + "op_args": ["foo"], + "op_kwargs": {"arg3": [1, 2, {"a": "b"}]}, + } + + def test_mapped_task_group_serde(): execution_date = datetime(2020, 1, 1) From ce071460ac028b818d9f7c07beb8d0bf52b3cae9 Mon Sep 17 00:00:00 2001 From: mhenc Date: Thu, 10 Feb 2022 10:41:59 +0100 Subject: [PATCH 08/19] Move Zombie detection to SchedulerJob (#21181) --- airflow/config_templates/config.yml | 7 + airflow/config_templates/default_airflow.cfg | 3 + airflow/dag_processing/manager.py | 58 +------- airflow/jobs/scheduler_job.py | 44 ++++++ airflow/models/dagbag.py | 1 - tests/dag_processing/test_manager.py | 149 +------------------ tests/jobs/test_scheduler_job.py | 113 +++++++++++++- 7 files changed, 168 insertions(+), 207 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 8ef738f4f4ad6..83ea6b94ccff5 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1850,6 +1850,13 @@ type: string example: ~ default: "300" + - name: zombie_detection_interval + description: | + How often (in seconds) should the scheduler check for zombie tasks. + version_added: 2.3.0 + type: float + example: ~ + default: "10.0" - name: catchup_by_default description: | Turn off scheduler catchup by setting this to ``False``. diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 520ab4442850d..55161a55d5e71 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -923,6 +923,9 @@ child_process_log_directory = {AIRFLOW_HOME}/logs/scheduler # associated task instance as failed and will re-schedule the task. scheduler_zombie_task_threshold = 300 +# How often (in seconds) should the scheduler check for zombie tasks. +zombie_detection_interval = 10.0 + # Turn off scheduler catchup by setting this to ``False``. 
# Default behavior is unchanged and # Command Line Backfills still work, but the scheduler diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 3b8a998551aca..33b219ccd58ea 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -34,7 +34,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast from setproctitle import setproctitle -from sqlalchemy import or_ from tabulate import tabulate import airflow.models @@ -42,17 +41,15 @@ from airflow.dag_processing.processor import DagFileProcessorProcess from airflow.models import DagModel, errors from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import SimpleTaskInstance from airflow.stats import Stats from airflow.utils import timezone -from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest, TaskCallbackRequest +from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest from airflow.utils.file import list_py_file_paths, might_contain_dag from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.mixins import MultiprocessingStartMethodMixin from airflow.utils.net import get_hostname from airflow.utils.process_utils import kill_child_processes_by_pids, reap_process_group from airflow.utils.session import provide_session -from airflow.utils.state import State if TYPE_CHECKING: import pathlib @@ -434,8 +431,6 @@ def __init__( # How often to print out DAG file processing stats to the log. Default to # 30 seconds. self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval') - # How many seconds do we wait for tasks to heartbeat before mark them as zombies. - self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold') # Map from file path to the processor self._processors: Dict[str, DagFileProcessorProcess] = {} @@ -445,13 +440,10 @@ def __init__( # Map from file path to stats about the file self._file_stats: Dict[str, DagFileStat] = {} - self._last_zombie_query_time = None # Last time that the DAG dir was traversed to look for files self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0)) # Last time stats were printed self.last_stat_print_time = 0 - # TODO: Remove magic number - self._zombie_query_interval = 10 # How long to wait before timing out a process to parse a DAG file self._processor_timeout = processor_timeout @@ -566,7 +558,6 @@ def _run_parsing_loop(self): self._processors.pop(processor.file_path) self._refresh_dag_dir() - self._find_zombies() self._kill_timed_out_processors() @@ -1023,53 +1014,6 @@ def prepare_file_path_queue(self): self._file_path_queue.extend(files_paths_to_queue) - @provide_session - def _find_zombies(self, session): - """ - Find zombie task instances, which are tasks haven't heartbeated for too long - and update the current zombie list. 
- """ - now = timezone.utcnow() - if ( - not self._last_zombie_query_time - or (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval - ): - # to avoid circular imports - from airflow.jobs.local_task_job import LocalTaskJob as LJ - - self.log.info("Finding 'running' jobs without a recent heartbeat") - TI = airflow.models.TaskInstance - DM = airflow.models.DagModel - limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs) - - zombies = ( - session.query(TI, DM.fileloc) - .join(LJ, TI.job_id == LJ.id) - .join(DM, TI.dag_id == DM.dag_id) - .filter(TI.state == State.RUNNING) - .filter( - or_( - LJ.state != State.RUNNING, - LJ.latest_heartbeat < limit_dttm, - ) - ) - .all() - ) - - if zombies: - self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm) - - self._last_zombie_query_time = timezone.utcnow() - for ti, file_loc in zombies: - request = TaskCallbackRequest( - full_filepath=file_loc, - simple_task_instance=SimpleTaskInstance(ti), - msg=f"Detected {ti} as zombie", - ) - self.log.error("Detected zombie job: %s", request) - self._add_callback_to_queue(request) - Stats.incr('zombies_killed') - def _kill_timed_out_processors(self): """Kill any file processors that timeout to defend against process hangs.""" now = timezone.utcnow() diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 7a6e3efd74ba9..62116553c675e 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -38,6 +38,7 @@ from airflow.dag_processing.manager import DagFileProcessorAgent from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS from airflow.jobs.base_job import BaseJob +from airflow.jobs.local_task_job import LocalTaskJob from airflow.models import DAG from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag @@ -123,6 +124,8 @@ def __init__( ) scheduler_idle_sleep_time = processor_poll_interval self._scheduler_idle_sleep_time = scheduler_idle_sleep_time + # How many seconds do we wait for tasks to heartbeat before mark them as zombies. + self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold') self.do_pickle = do_pickle super().__init__(*args, **kwargs) @@ -739,6 +742,11 @@ def _run_scheduler_loop(self) -> None: self._emit_pool_metrics, ) + timers.call_regular_interval( + conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0), + self._find_zombies, + ) + for loop_count in itertools.count(start=1): with Stats.timer() as timer: @@ -1259,3 +1267,39 @@ def check_trigger_timeouts(self, session: Session = None): ) if num_timed_out_tasks: self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks) + + @provide_session + def _find_zombies(self, session): + """ + Find zombie task instances, which are tasks haven't heartbeated for too long + and update the current zombie list. 
+ """ + self.log.debug("Finding 'running' jobs without a recent heartbeat") + limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs) + + zombies = ( + session.query(TaskInstance, DagModel.fileloc) + .join(LocalTaskJob, TaskInstance.job_id == LocalTaskJob.id) + .join(DagModel, TaskInstance.dag_id == DagModel.dag_id) + .filter(TaskInstance.state == State.RUNNING) + .filter( + or_( + LocalTaskJob.state != State.RUNNING, + LocalTaskJob.latest_heartbeat < limit_dttm, + ) + ) + .all() + ) + + if zombies: + self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm) + + for ti, file_loc in zombies: + request = TaskCallbackRequest( + full_filepath=file_loc, + simple_task_instance=SimpleTaskInstance(ti), + msg=f"Detected {ti} as zombie", + ) + self.log.error("Detected zombie job: %s", request) + self.processor_agent.send_callback_to_execute(request) + Stats.incr('zombies_killed') diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index eac62dcd39506..0136d7fc13fd9 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -89,7 +89,6 @@ class DagBag(LoggingMixin): """ DAGBAG_IMPORT_TIMEOUT = conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT') - SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint('scheduler', 'scheduler_zombie_task_threshold') def __init__( self, diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index 5ea21a216b176..2746e5963806f 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -45,17 +45,13 @@ DagParsingStat, ) from airflow.dag_processing.processor import DagFileProcessorProcess -from airflow.jobs.local_task_job import LocalTaskJob as LJ -from airflow.models import DagBag, DagModel, TaskInstance as TI, errors +from airflow.models import DagBag, DagModel, errors from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone -from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest +from airflow.utils.callback_requests import CallbackRequest from airflow.utils.net import get_hostname from airflow.utils.session import create_session -from airflow.utils.state import DagRunState, State -from airflow.utils.types import DagRunType from tests.core.test_logging_config import SETTINGS_FILE_VALID, settings_context from tests.models import TEST_DAGS_FOLDER from tests.test_utils.config import conf_vars @@ -455,147 +451,6 @@ def test_recently_modified_file_is_parsed_with_mtime_mode( > (freezed_base_time - manager.get_last_finish_time("file_1.py")).total_seconds() ) - def test_find_zombies(self): - manager = DagFileProcessorManager( - dag_directory='directory', - max_runs=1, - processor_timeout=timedelta.max, - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, - ) - - dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) - with create_session() as session: - session.query(LJ).delete() - dag = dagbag.get_dag('example_branch_operator') - dag.sync_to_db() - task = dag.get_task(task_id='run_this_first') - - dag_run = dag.create_dagrun( - state=DagRunState.RUNNING, - execution_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - session=session, - ) - - ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING) - local_job = LJ(ti) - local_job.state = State.SHUTDOWN - - session.add(local_job) - session.flush() - - ti.job_id = local_job.id - session.add(ti) 
- session.flush() - - manager._last_zombie_query_time = timezone.utcnow() - timedelta( - seconds=manager._zombie_threshold_secs + 1 - ) - manager._find_zombies() - requests = manager._callback_to_execute[dag.fileloc] - assert 1 == len(requests) - assert requests[0].full_filepath == dag.fileloc - assert requests[0].msg == f"Detected {ti} as zombie" - assert requests[0].is_failure_callback is True - assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance) - assert ti.dag_id == requests[0].simple_task_instance.dag_id - assert ti.task_id == requests[0].simple_task_instance.task_id - assert ti.run_id == requests[0].simple_task_instance.run_id - - session.query(TI).delete() - session.query(LJ).delete() - - @mock.patch('airflow.dag_processing.manager.DagFileProcessorProcess') - def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor( - self, mock_processor - ): - """ - Check that the same set of failure callback with zombies are passed to the dag - file processors until the next zombie detection logic is invoked. - """ - test_dag_path = TEST_DAG_FOLDER / 'test_example_bash_operator.py' - with conf_vars({('scheduler', 'parsing_processes'): '1', ('core', 'load_examples'): 'False'}): - dagbag = DagBag(test_dag_path, read_dags_from_db=False) - with create_session() as session: - session.query(LJ).delete() - dag = dagbag.get_dag('test_example_bash_operator') - dag.sync_to_db() - - dag_run = dag.create_dagrun( - state=DagRunState.RUNNING, - execution_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - session=session, - ) - task = dag.get_task(task_id='run_this_last') - - ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING) - local_job = LJ(ti) - local_job.state = State.SHUTDOWN - session.add(local_job) - session.flush() - - # TODO: If there was an actual Relationship between TI and Job - # we wouldn't need this extra commit - session.add(ti) - ti.job_id = local_job.id - session.flush() - - expected_failure_callback_requests = [ - TaskCallbackRequest( - full_filepath=dag.fileloc, - simple_task_instance=SimpleTaskInstance(ti), - msg="Message", - ) - ] - - test_dag_path = TEST_DAG_FOLDER / 'test_example_bash_operator.py' - - child_pipe, parent_pipe = multiprocessing.Pipe() - async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') - - fake_processors = [] - - def fake_processor_(*args, **kwargs): - nonlocal fake_processors - processor = FakeDagFileProcessorRunner._create_process(*args, **kwargs) - fake_processors.append(processor) - return processor - - mock_processor.side_effect = fake_processor_ - - manager = DagFileProcessorManager( - dag_directory=test_dag_path, - max_runs=1, - processor_timeout=timedelta.max, - signal_conn=child_pipe, - dag_ids=[], - pickle_dags=False, - async_mode=async_mode, - ) - - self.run_processor_manager_one_loop(manager, parent_pipe) - - if async_mode: - # Once for initial parse, and then again for the add_callback_to_queue - assert len(fake_processors) == 2 - assert fake_processors[0]._file_path == str(test_dag_path) - assert fake_processors[0]._callback_requests == [] - else: - assert len(fake_processors) == 1 - - assert fake_processors[-1]._file_path == str(test_dag_path) - callback_requests = fake_processors[-1]._callback_requests - assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == { - result.simple_task_instance.key for result in callback_requests - } - - child_pipe.close() - parent_pipe.close() - 
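# ---------------------------------------------------------------------------
# [Editor's aside -- illustrative sketch, not part of the patch.] With the manager-side
# zombie logic and its tests removed above, the periodic check is now driven from the
# scheduler's own event loop. The snippet below mirrors the timers.call_regular_interval()
# call this patch adds to SchedulerJob._run_scheduler_loop; check_for_zombies stands in
# for SchedulerJob._find_zombies, and the 10-second delay matches the fallback used there.
from airflow.utils.event_scheduler import EventScheduler

def check_for_zombies():
    print("looking for running tasks without a recent heartbeat")

timers = EventScheduler()
# call_regular_interval() re-enqueues the callback after every run, so a single
# registration keeps the check recurring for the lifetime of the scheduler loop.
timers.call_regular_interval(10.0, check_for_zombies)
# The scheduler drains whatever is due once per loop iteration; blocking=False returns
# immediately instead of sleeping until the next scheduled event.
timers.run(blocking=False)
# ---------------------------------------------------------------------------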
@mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock) @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.kill") def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid): diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 707f587223f0a..845ffda016ac1 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -40,16 +40,17 @@ from airflow.executors.base_executor import BaseExecutor from airflow.jobs.backfill_job import BackfillJob from airflow.jobs.base_job import BaseJob +from airflow.jobs.local_task_job import LocalTaskJob from airflow.jobs.scheduler_job import SchedulerJob from airflow.models import DAG, DagBag, DagModel, Pool, TaskInstance from airflow.models.dagrun import DagRun from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import TaskInstanceKey +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone -from airflow.utils.callback_requests import DagCallbackRequest +from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState @@ -3480,6 +3481,114 @@ def test_timeout_triggers(self, dag_maker): assert ti1.next_method == "__fail__" assert ti2.state == State.DEFERRED + def test_find_zombies_nothing(self): + with create_session() as session: + self.scheduler_job = SchedulerJob() + self.scheduler_job.processor_agent = mock.MagicMock() + + self.scheduler_job._find_zombies(session=session) + + self.scheduler_job.processor_agent.send_callback_to_execute.assert_not_called() + + def test_find_zombies(self): + dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) + with create_session() as session: + session.query(LocalTaskJob).delete() + dag = dagbag.get_dag('example_branch_operator') + dag.sync_to_db() + task = dag.get_task(task_id='run_this_first') + + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) + + ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) + local_job = LocalTaskJob(ti) + local_job.state = State.SHUTDOWN + + session.add(local_job) + session.flush() + + ti.job_id = local_job.id + session.add(ti) + session.flush() + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + + self.scheduler_job._find_zombies(session=session) + + self.scheduler_job.processor_agent.send_callback_to_execute.assert_called_once() + requests = self.scheduler_job.processor_agent.send_callback_to_execute.call_args[0] + assert 1 == len(requests) + assert requests[0].full_filepath == dag.fileloc + assert requests[0].msg == f"Detected {ti} as zombie" + assert requests[0].is_failure_callback is True + assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance) + assert ti.dag_id == requests[0].simple_task_instance.dag_id + assert ti.task_id == requests[0].simple_task_instance.task_id + assert ti.run_id == requests[0].simple_task_instance.run_id + + session.query(TaskInstance).delete() + session.query(LocalTaskJob).delete() 
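# ---------------------------------------------------------------------------
# [Editor's aside -- illustrative sketch, not part of the patch.] The tests above exercise
# the two [scheduler] options that now control zombie handling. The option names come from
# the conf.getint()/conf.getfloat() calls in this patch; the values below are arbitrary
# examples (not Airflow defaults) and use the standard AIRFLOW__SECTION__KEY environment
# variable convention for overriding airflow.cfg settings.
import os

# Treat a running task as a zombie once its LocalTaskJob heartbeat is older than 5 minutes.
os.environ["AIRFLOW__SCHEDULER__SCHEDULER_ZOMBIE_TASK_THRESHOLD"] = "300"
# Run SchedulerJob._find_zombies every 30 seconds; the code falls back to 10.0 seconds
# when the new zombie_detection_interval option is left unset.
os.environ["AIRFLOW__SCHEDULER__ZOMBIE_DETECTION_INTERVAL"] = "30"
# ---------------------------------------------------------------------------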
+ + def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor(self): + """ + Check that the same set of failure callback with zombies are passed to the dag + file processors until the next zombie detection logic is invoked. + """ + with conf_vars({('core', 'load_examples'): 'False'}): + dagbag = DagBag( + dag_folder=os.path.join(settings.DAGS_FOLDER, "test_example_bash_operator.py"), + read_dags_from_db=False, + ) + session = settings.Session() + session.query(LocalTaskJob).delete() + dag = dagbag.get_dag('test_example_bash_operator') + dag.sync_to_db() + + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) + task = dag.get_task(task_id='run_this_last') + + ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) + local_job = LocalTaskJob(ti) + local_job.state = State.SHUTDOWN + session.add(local_job) + session.flush() + + # TODO: If there was an actual Relationship between TI and Job + # we wouldn't need this extra commit + session.add(ti) + ti.job_id = local_job.id + session.flush() + + expected_failure_callback_requests = [ + TaskCallbackRequest( + full_filepath=dag.fileloc, + simple_task_instance=SimpleTaskInstance(ti), + msg="Message", + ) + ] + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + + self.scheduler_job._find_zombies(session=session) + + self.scheduler_job.processor_agent.send_callback_to_execute.assert_called_once() + callback_requests = self.scheduler_job.processor_agent.send_callback_to_execute.call_args[0] + assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == { + result.simple_task_instance.key for result in callback_requests + } + @pytest.mark.xfail(reason="Work out where this goes") def test_task_with_upstream_skip_process_task_instances(): From 862ec769544ef360bb8ed6a1cf68a6d5949ea546 Mon Sep 17 00:00:00 2001 From: Igor Kholopov Date: Thu, 10 Feb 2022 10:53:59 +0100 Subject: [PATCH 09/19] Modernize DAG-related URL routes and rename "tree" to "grid" (#20730) Co-authored-by: Igor Kholopov --- airflow/www/decorators.py | 13 +- airflow/www/templates/airflow/dag.html | 10 +- airflow/www/templates/airflow/dags.html | 8 +- airflow/www/utils.py | 4 +- airflow/www/views.py | 169 ++++++++++++++++++++--- tests/www/views/test_views_decorators.py | 20 ++- tests/www/views/test_views_tasks.py | 85 ++++++++++-- 7 files changed, 265 insertions(+), 44 deletions(-) diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index 080fe682991c2..47a9847c2eecf 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -20,6 +20,7 @@ import gzip import logging from io import BytesIO as IO +from itertools import chain from typing import Callable, TypeVar, cast import pendulum @@ -48,13 +49,19 @@ def wrapper(*args, **kwargs): user = g.user.username fields_skip_logging = {'csrf_token', '_csrf_token'} + log_fields = { + k: v + for k, v in chain(request.values.items(), request.view_args.items()) + if k not in fields_skip_logging + } + log = Log( event=f.__name__, task_instance=None, owner=user, - extra=str([(k, v) for k, v in request.values.items() if k not in fields_skip_logging]), - task_id=request.values.get('task_id'), - dag_id=request.values.get('dag_id'), + extra=str([(k, log_fields[k]) for k in log_fields]), + task_id=log_fields.get('task_id'), + dag_id=log_fields.get('dag_id'), ) if 'execution_date' in request.values: diff --git 
a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index e57a1f1fc8b3b..0f4b967373da5 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -110,9 +110,9 @@