diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 259303c3073e8..a73abec8487f8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -232,7 +232,7 @@ repos:
name: Run pydocstyle
args:
- --convention=pep257
- - --add-ignore=D100,D102,D103,D104,D105,D107,D205,D400,D401
+ - --add-ignore=D100,D102,D103,D104,D105,D107,D202,D205,D400,D401
exclude: |
(?x)
^tests/.*\.py$|
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 8ef738f4f4ad6..83ea6b94ccff5 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1850,6 +1850,13 @@
type: string
example: ~
default: "300"
+ - name: zombie_detection_interval
+ description: |
+ How often (in seconds) should the scheduler check for zombie tasks.
+ version_added: 2.3.0
+ type: float
+ example: ~
+ default: "10.0"
- name: catchup_by_default
description: |
Turn off scheduler catchup by setting this to ``False``.
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 520ab4442850d..55161a55d5e71 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -923,6 +923,9 @@ child_process_log_directory = {AIRFLOW_HOME}/logs/scheduler
# associated task instance as failed and will re-schedule the task.
scheduler_zombie_task_threshold = 300
+# How often (in seconds) should the scheduler check for zombie tasks.
+zombie_detection_interval = 10.0
+
# Turn off scheduler catchup by setting this to ``False``.
# Default behavior is unchanged and
# Command Line Backfills still work, but the scheduler
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index 3b8a998551aca..33b219ccd58ea 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -34,7 +34,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast
from setproctitle import setproctitle
-from sqlalchemy import or_
from tabulate import tabulate
import airflow.models
@@ -42,17 +41,15 @@
from airflow.dag_processing.processor import DagFileProcessorProcess
from airflow.models import DagModel, errors
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import SimpleTaskInstance
from airflow.stats import Stats
from airflow.utils import timezone
-from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest, TaskCallbackRequest
+from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest
from airflow.utils.file import list_py_file_paths, might_contain_dag
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.mixins import MultiprocessingStartMethodMixin
from airflow.utils.net import get_hostname
from airflow.utils.process_utils import kill_child_processes_by_pids, reap_process_group
from airflow.utils.session import provide_session
-from airflow.utils.state import State
if TYPE_CHECKING:
import pathlib
@@ -434,8 +431,6 @@ def __init__(
# How often to print out DAG file processing stats to the log. Default to
# 30 seconds.
self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval')
- # How many seconds do we wait for tasks to heartbeat before mark them as zombies.
- self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
# Map from file path to the processor
self._processors: Dict[str, DagFileProcessorProcess] = {}
@@ -445,13 +440,10 @@ def __init__(
# Map from file path to stats about the file
self._file_stats: Dict[str, DagFileStat] = {}
- self._last_zombie_query_time = None
# Last time that the DAG dir was traversed to look for files
self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0))
# Last time stats were printed
self.last_stat_print_time = 0
- # TODO: Remove magic number
- self._zombie_query_interval = 10
# How long to wait before timing out a process to parse a DAG file
self._processor_timeout = processor_timeout
@@ -566,7 +558,6 @@ def _run_parsing_loop(self):
self._processors.pop(processor.file_path)
self._refresh_dag_dir()
- self._find_zombies()
self._kill_timed_out_processors()
@@ -1023,53 +1014,6 @@ def prepare_file_path_queue(self):
self._file_path_queue.extend(files_paths_to_queue)
- @provide_session
- def _find_zombies(self, session):
- """
- Find zombie task instances, which are tasks haven't heartbeated for too long
- and update the current zombie list.
- """
- now = timezone.utcnow()
- if (
- not self._last_zombie_query_time
- or (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval
- ):
- # to avoid circular imports
- from airflow.jobs.local_task_job import LocalTaskJob as LJ
-
- self.log.info("Finding 'running' jobs without a recent heartbeat")
- TI = airflow.models.TaskInstance
- DM = airflow.models.DagModel
- limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs)
-
- zombies = (
- session.query(TI, DM.fileloc)
- .join(LJ, TI.job_id == LJ.id)
- .join(DM, TI.dag_id == DM.dag_id)
- .filter(TI.state == State.RUNNING)
- .filter(
- or_(
- LJ.state != State.RUNNING,
- LJ.latest_heartbeat < limit_dttm,
- )
- )
- .all()
- )
-
- if zombies:
- self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm)
-
- self._last_zombie_query_time = timezone.utcnow()
- for ti, file_loc in zombies:
- request = TaskCallbackRequest(
- full_filepath=file_loc,
- simple_task_instance=SimpleTaskInstance(ti),
- msg=f"Detected {ti} as zombie",
- )
- self.log.error("Detected zombie job: %s", request)
- self._add_callback_to_queue(request)
- Stats.incr('zombies_killed')
-
def _kill_timed_out_processors(self):
"""Kill any file processors that timeout to defend against process hangs."""
now = timezone.utcnow()
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 9cf423fb69e11..53a12c62d2318 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -280,30 +280,88 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names
names = ", ".join(repr(n) for n in unknown_args)
raise TypeError(f'{funcname} got unexpected keyword arguments {names}')
- def map(
- self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
- ) -> XComArg:
+ def map(self, *args, **kwargs) -> XComArg:
self._validate_arg_names("map", kwargs)
- dag = dag or DagContext.get_current_dag()
- task_group = task_group or TaskGroupContext.get_current_task_group(dag)
- task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
- operator = MappedOperator.from_decorator(
- decorator=self,
+ partial_kwargs = self.kwargs.copy()
+ dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
+ task_group = partial_kwargs.pop("task_group", TaskGroupContext.get_current_task_group(dag))
+ task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
+
+ # Unfortunately attrs's type hinting support does not work well with
+ # subclassing; it complains that arguments forwarded to the superclass
+ # are "unexpected" (they are fine at runtime).
+ operator = cast(Any, DecoratedMappedOperator)(
+ operator_class=self.operator_class,
+ partial_kwargs=partial_kwargs,
+ mapped_kwargs={},
+ task_id=task_id,
dag=dag,
task_group=task_group,
- task_id=task_id,
- mapped_kwargs=kwargs,
+ deps=MappedOperator._deps(self.operator_class.deps),
+ multiple_outputs=self.multiple_outputs,
+ python_callable=self.function,
)
+
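+ # Arguments passed to .map() become the mapped op_args/op_kwargs, expanded at runtime.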
+ operator.mapped_kwargs["op_args"] = list(args)
+ operator.mapped_kwargs["op_kwargs"] = kwargs
+
+ for arg in itertools.chain(args, kwargs.values()):
+ XComArg.apply_upstream_relationship(operator, arg)
return XComArg(operator=operator)
- def partial(
- self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
- ) -> "_TaskDecorator[Function, OperatorSubclass]":
- self._validate_arg_names("partial", kwargs, {'task_id'})
- partial_kwargs = self.kwargs.copy()
- partial_kwargs.update(kwargs)
- return attr.evolve(self, kwargs=partial_kwargs)
+ def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+ self._validate_arg_names("partial", kwargs)
+
+ op_args = self.kwargs.get("op_args", [])
+ op_args.extend(args)
+
+ op_kwargs = self.kwargs.get("op_kwargs", {})
+ op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial")
+
+ return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs})
+
+
+def _merge_kwargs(
+ kwargs1: Dict[str, XComArg],
+ kwargs2: Dict[str, XComArg],
+ *,
+ fail_reason: str,
+) -> Dict[str, XComArg]:
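+ """Merge two kwargs dicts, raising TypeError if any key appears in both."""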
+ duplicated_keys = set(kwargs1).intersection(kwargs2)
+ if len(duplicated_keys) == 1:
+ raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
+ elif duplicated_keys:
+ duplicated_keys_display = ", ".join(sorted(duplicated_keys))
+ raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
+ return {**kwargs1, **kwargs2}
+
+
+@attr.define(kw_only=True)
+class DecoratedMappedOperator(MappedOperator):
+ """MappedOperator implementation for @task-decorated task function."""
+
+ multiple_outputs: bool
+ python_callable: Callable
+
+ def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
+ assert not isinstance(self.operator_class, str)
+ op_args = self.partial_kwargs.pop("op_args", []) + self.mapped_kwargs.pop("op_args", [])
+ op_kwargs = _merge_kwargs(
+ self.partial_kwargs.pop("op_kwargs", {}),
+ self.mapped_kwargs.pop("op_kwargs", {}),
+ fail_reason="mapping already partial",
+ )
+ return self.operator_class(
+ dag=dag,
+ task_id=self.task_id,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ multiple_outputs=self.multiple_outputs,
+ python_callable=self.python_callable,
+ **self.partial_kwargs,
+ **self.mapped_kwargs,
+ )
class Task(Generic[Function]):
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 7a6e3efd74ba9..62116553c675e 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -38,6 +38,7 @@
from airflow.dag_processing.manager import DagFileProcessorAgent
from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS
from airflow.jobs.base_job import BaseJob
+from airflow.jobs.local_task_job import LocalTaskJob
from airflow.models import DAG
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
@@ -123,6 +124,8 @@ def __init__(
)
scheduler_idle_sleep_time = processor_poll_interval
self._scheduler_idle_sleep_time = scheduler_idle_sleep_time
+ # How many seconds do we wait for tasks to heartbeat before marking them as zombies.
+ self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
self.do_pickle = do_pickle
super().__init__(*args, **kwargs)
@@ -739,6 +742,11 @@ def _run_scheduler_loop(self) -> None:
self._emit_pool_metrics,
)
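+ # Check for zombie tasks on a dedicated interval, configurable via [scheduler] zombie_detection_interval.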
+ timers.call_regular_interval(
+ conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0),
+ self._find_zombies,
+ )
+
for loop_count in itertools.count(start=1):
with Stats.timer() as timer:
@@ -1259,3 +1267,39 @@ def check_trigger_timeouts(self, session: Session = None):
)
if num_timed_out_tasks:
self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
+
+ @provide_session
+ def _find_zombies(self, session):
+ """
+ Find zombie task instances, which are tasks that haven't heartbeated for too long,
+ and update the current zombie list.
+ """
+ self.log.debug("Finding 'running' jobs without a recent heartbeat")
+ limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs)
+
+ zombies = (
+ session.query(TaskInstance, DagModel.fileloc)
+ .join(LocalTaskJob, TaskInstance.job_id == LocalTaskJob.id)
+ .join(DagModel, TaskInstance.dag_id == DagModel.dag_id)
+ .filter(TaskInstance.state == State.RUNNING)
+ .filter(
+ or_(
+ LocalTaskJob.state != State.RUNNING,
+ LocalTaskJob.latest_heartbeat < limit_dttm,
+ )
+ )
+ .all()
+ )
+
+ if zombies:
+ self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm)
+
+ for ti, file_loc in zombies:
+ request = TaskCallbackRequest(
+ full_filepath=file_loc,
+ simple_task_instance=SimpleTaskInstance(ti),
+ msg=f"Detected {ti} as zombie",
+ )
+ self.log.error("Detected zombie job: %s", request)
+ self.processor_agent.send_callback_to_execute(request)
+ Stats.incr('zombies_killed')
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 34c84128391dc..35a0fbb53fe8a 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -82,7 +82,6 @@
from airflow.utils.weight_rule import WeightRule
if TYPE_CHECKING:
- from airflow.decorators.base import _TaskDecorator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
@@ -243,7 +242,7 @@ def __new__(cls, name, bases, namespace, **kwargs):
return new_cls
# The class level partial function. This is what handles the actual mapping
- def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs):
+ def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs) -> "MappedOperator":
operator_class = cast("Type[BaseOperator]", cls)
# Validate that the args we passed are known -- at call/DAG parse time, not run time!
_validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
@@ -671,7 +670,7 @@ def __init__(
)
self.trigger_rule = trigger_rule
- self.depends_on_past = depends_on_past
+ self.depends_on_past: bool = depends_on_past
self.wait_for_downstream = wait_for_downstream
if wait_for_downstream:
self.depends_on_past = True
@@ -714,7 +713,7 @@ def __init__(
stacklevel=2,
)
max_active_tis_per_dag = task_concurrency
- self.max_active_tis_per_dag = max_active_tis_per_dag
+ self.max_active_tis_per_dag: Optional[int] = max_active_tis_per_dag
self.do_xcom_push = do_xcom_push
self.doc_md = doc_md
@@ -1632,7 +1631,7 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) ->
dag._remove_task(operator.task_id)
operator_init_kwargs: dict = operator._BaseOperator__init_kwargs # type: ignore
- return MappedOperator(
+ return cls(
operator_class=type(operator),
task_id=operator.task_id,
task_group=task_group,
@@ -1648,37 +1647,6 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) ->
deps=cls._deps(operator.deps),
)
- @classmethod
- def from_decorator(
- cls,
- *,
- decorator: "_TaskDecorator",
- dag: Optional["DAG"],
- task_group: Optional["TaskGroup"],
- task_id: str,
- mapped_kwargs: Dict[str, Any],
- ) -> "MappedOperator":
- """Create a mapped operator from a task decorator.
-
- Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``.
- The task decorator calling this should be responsible for validation.
- """
- from airflow.models.xcom_arg import XComArg
-
- operator = MappedOperator(
- operator_class=decorator.operator_class,
- partial_kwargs=decorator.kwargs,
- mapped_kwargs={},
- task_id=task_id,
- dag=dag,
- task_group=task_group,
- deps=cls._deps(decorator.operator_class.deps),
- )
- operator.mapped_kwargs.update(mapped_kwargs)
- for arg in mapped_kwargs.values():
- XComArg.apply_upstream_relationship(operator, arg)
- return operator
-
@classmethod
def _deps(cls, deps: Iterable[BaseTIDep]):
if deps is BaseOperator.deps:
@@ -1749,7 +1717,7 @@ def inherits_from_dummy_operator(self):
@classmethod
def get_serialized_fields(cls):
if cls.__serialized_fields is None:
- fields_dict = attr.fields_dict(cls)
+ fields_dict = attr.fields_dict(MappedOperator)
cls.__serialized_fields = frozenset(
fields_dict.keys()
- {
@@ -1902,22 +1870,17 @@ def expand_mapped_task(
return ret
- def unmap(self) -> BaseOperator:
- """Get the "normal" Operator after applying the current mapping"""
+ def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
assert not isinstance(self.operator_class, str)
+ return self.operator_class(dag=dag, task_id=self.task_id, **self.partial_kwargs, **self.mapped_kwargs)
+ def unmap(self) -> BaseOperator:
+ """Get the "normal" Operator after applying the current mapping"""
dag = self.get_dag()
if not dag:
- raise RuntimeError("Cannot unmapp a task unless it has a dag")
-
- args = {
- **self.partial_kwargs,
- **self.mapped_kwargs,
- }
+ raise RuntimeError("Cannot unmap a task unless it has a DAG")
dag._remove_task(self.task_id)
- task = self.operator_class(task_id=self.task_id, dag=self.dag, **args)
-
- return task
+ return self.create_unmapped_operator(dag)
# TODO: Deprecate for Airflow 3.0
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index eac62dcd39506..0136d7fc13fd9 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -89,7 +89,6 @@ class DagBag(LoggingMixin):
"""
DAGBAG_IMPORT_TIMEOUT = conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT')
- SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
def __init__(
self,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4996b9a7db073..f10032dfc2964 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1713,7 +1713,7 @@ def handle_failure(
test_mode: Optional[bool] = None,
force_fail: bool = False,
error_file: Optional[str] = None,
- session=NEW_SESSION,
+ session: Session = NEW_SESSION,
) -> None:
"""Handle Failure for the TaskInstance"""
if test_mode is None:
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py
index afd19912aa1a9..26e10d7e5f4ce 100644
--- a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py
+++ b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py
@@ -14,11 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# Ignore missing args provided by default_args
-# type: ignore[call-arg]
-
-import os
+import json
from datetime import datetime
from airflow.models.dag import DAG
@@ -37,7 +33,6 @@
{
"cluster_name": "templated-cluster",
"cluster_role_arn": "arn:aws:iam::123456789012:role/role_name",
- "nodegroup_subnets": ["subnet-12345ab", "subnet-67890cd"],
"resources_vpc_config": {
"subnetIds": ["subnet-12345ab", "subnet-67890cd"],
"endpointPublicAccess": true,
@@ -49,25 +44,24 @@
"""
with DAG(
- dag_id='to-publish-manuals-templated',
- default_args={'cluster_name': "{{ dag_run.conf['cluster_name'] }}"},
+ dag_id='example_eks_templated',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
- catchup=False,
- max_active_runs=1,
tags=['example', 'templated'],
+ catchup=False,
# render_template_as_native_obj=True is what converts the Jinja to Python objects, instead of a string.
render_template_as_native_obj=True,
) as dag:
- SUBNETS = os.environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ')
- VPC_CONFIG = {
- 'subnetIds': SUBNETS,
- 'endpointPublicAccess': True,
- 'endpointPrivateAccess': False,
- }
+
+ CLUSTER_NAME = "{{ dag_run.conf['cluster_name'] }}"
+ NODEGROUP_NAME = "{{ dag_run.conf['nodegroup_name'] }}"
+ VPC_CONFIG = json.loads("{{ dag_run.conf['resources_vpc_config'] }}")
+ SUBNETS = VPC_CONFIG['subnetIds']
+
# Create an Amazon EKS Cluster control plane without attaching a compute service.
create_cluster = EksCreateClusterOperator(
task_id='create_eks_cluster',
+ cluster_name=CLUSTER_NAME,
compute=None,
cluster_role_arn="{{ dag_run.conf['cluster_role_arn'] }}",
resources_vpc_config=VPC_CONFIG,
@@ -75,24 +69,28 @@
await_create_cluster = EksClusterStateSensor(
task_id='wait_for_create_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.ACTIVE,
)
create_nodegroup = EksCreateNodegroupOperator(
task_id='create_eks_nodegroup',
- nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}",
- nodegroup_subnets="{{ dag_run.conf['nodegroup_subnets'] }}",
+ cluster_name=CLUSTER_NAME,
+ nodegroup_name=NODEGROUP_NAME,
+ nodegroup_subnets=SUBNETS,
nodegroup_role_arn="{{ dag_run.conf['nodegroup_role_arn'] }}",
)
await_create_nodegroup = EksNodegroupStateSensor(
task_id='wait_for_create_nodegroup',
- nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}",
+ cluster_name=CLUSTER_NAME,
+ nodegroup_name=NODEGROUP_NAME,
target_state=NodegroupStates.ACTIVE,
)
start_pod = EksPodOperator(
task_id="run_pod",
+ cluster_name=CLUSTER_NAME,
pod_name="run_pod",
image="amazon/aws-cli:latest",
cmds=["sh", "-c", "ls"],
@@ -104,21 +102,25 @@
delete_nodegroup = EksDeleteNodegroupOperator(
task_id='delete_eks_nodegroup',
- nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}",
+ cluster_name=CLUSTER_NAME,
+ nodegroup_name=NODEGROUP_NAME,
)
await_delete_nodegroup = EksNodegroupStateSensor(
task_id='wait_for_delete_nodegroup',
- nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}",
+ cluster_name=CLUSTER_NAME,
+ nodegroup_name=NODEGROUP_NAME,
target_state=NodegroupStates.NONEXISTENT,
)
delete_cluster = EksDeleteClusterOperator(
task_id='delete_eks_cluster',
+ cluster_name=CLUSTER_NAME,
)
await_delete_cluster = EksClusterStateSensor(
task_id='wait_for_delete_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.NONEXISTENT,
)
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py
index 4107058b5ac20..e08e6525e6fc8 100644
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py
+++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py
@@ -43,19 +43,18 @@
with DAG(
- dag_id='example-create-cluster-and-fargate-all-in-one',
- default_args={'cluster_name': CLUSTER_NAME},
+ dag_id='example_eks_with_fargate_in_one_step',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
- catchup=False,
- max_active_runs=1,
tags=['example'],
+ catchup=False,
) as dag:
# [START howto_operator_eks_create_cluster_with_fargate_profile]
# Create an Amazon EKS cluster control plane and an AWS Fargate compute platform in one step.
create_cluster_and_fargate_profile = EksCreateClusterOperator(
task_id='create_eks_cluster_and_fargate_profile',
+ cluster_name=CLUSTER_NAME,
cluster_role_arn=ROLE_ARN,
resources_vpc_config=VPC_CONFIG,
compute='fargate',
@@ -68,6 +67,7 @@
await_create_fargate_profile = EksFargateProfileStateSensor(
task_id='wait_for_create_fargate_profile',
+ cluster_name=CLUSTER_NAME,
fargate_profile_name=FARGATE_PROFILE_NAME,
target_state=FargateProfileStates.ACTIVE,
)
@@ -75,6 +75,7 @@
start_pod = EksPodOperator(
task_id="run_pod",
pod_name="run_pod",
+ cluster_name=CLUSTER_NAME,
image="amazon/aws-cli:latest",
cmds=["sh", "-c", "echo Test Airflow; date"],
labels={"demo": "hello_world"},
@@ -86,11 +87,14 @@
# An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles.
# Setting the `force` to `True` will delete any attached resources before deleting the cluster.
delete_all = EksDeleteClusterOperator(
- task_id='delete_fargate_profile_and_cluster', force_delete_compute=True
+ task_id='delete_fargate_profile_and_cluster',
+ cluster_name=CLUSTER_NAME,
+ force_delete_compute=True,
)
await_delete_cluster = EksClusterStateSensor(
task_id='wait_for_delete_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.NONEXISTENT,
)
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
index e58e2de729b62..3ca3b2eb87728 100644
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
+++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
@@ -46,18 +46,17 @@
with DAG(
- dag_id='example_eks_with_fargate_profile_dag',
- default_args={'cluster_name': CLUSTER_NAME},
+ dag_id='example_eks_with_fargate_profile',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
- catchup=False,
- max_active_runs=1,
tags=['example'],
+ catchup=False,
) as dag:
# Create an Amazon EKS Cluster control plane without attaching a compute service.
create_cluster = EksCreateClusterOperator(
task_id='create_eks_cluster',
+ cluster_name=CLUSTER_NAME,
cluster_role_arn=ROLE_ARN,
resources_vpc_config=VPC_CONFIG,
compute=None,
@@ -65,26 +64,32 @@
await_create_cluster = EksClusterStateSensor(
task_id='wait_for_create_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.ACTIVE,
)
# [START howto_operator_eks_create_fargate_profile]
create_fargate_profile = EksCreateFargateProfileOperator(
task_id='create_eks_fargate_profile',
+ cluster_name=CLUSTER_NAME,
pod_execution_role_arn=ROLE_ARN,
fargate_profile_name=FARGATE_PROFILE_NAME,
selectors=SELECTORS,
)
# [END howto_operator_eks_create_fargate_profile]
+ # [START howto_sensor_eks_fargate]
await_create_fargate_profile = EksFargateProfileStateSensor(
task_id='wait_for_create_fargate_profile',
+ cluster_name=CLUSTER_NAME,
fargate_profile_name=FARGATE_PROFILE_NAME,
target_state=FargateProfileStates.ACTIVE,
)
+ # [END howto_sensor_eks_fargate]
start_pod = EksPodOperator(
task_id="run_pod",
+ cluster_name=CLUSTER_NAME,
pod_name="run_pod",
image="amazon/aws-cli:latest",
cmds=["sh", "-c", "echo Test Airflow; date"],
@@ -97,20 +102,26 @@
# [START howto_operator_eks_delete_fargate_profile]
delete_fargate_profile = EksDeleteFargateProfileOperator(
task_id='delete_eks_fargate_profile',
+ cluster_name=CLUSTER_NAME,
fargate_profile_name=FARGATE_PROFILE_NAME,
)
# [END howto_operator_eks_delete_fargate_profile]
await_delete_fargate_profile = EksFargateProfileStateSensor(
task_id='wait_for_delete_fargate_profile',
+ cluster_name=CLUSTER_NAME,
fargate_profile_name=FARGATE_PROFILE_NAME,
target_state=FargateProfileStates.NONEXISTENT,
)
- delete_cluster = EksDeleteClusterOperator(task_id='delete_eks_cluster')
+ delete_cluster = EksDeleteClusterOperator(
+ task_id='delete_eks_cluster',
+ cluster_name=CLUSTER_NAME,
+ )
await_delete_cluster = EksClusterStateSensor(
task_id='wait_for_delete_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.NONEXISTENT,
)
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py
index f19eec622f295..38d1bd1ad4c2f 100644
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py
+++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py
@@ -42,19 +42,18 @@
with DAG(
- dag_id='example_eks_using_defaults_dag',
- default_args={'cluster_name': CLUSTER_NAME},
+ dag_id='example_eks_with_nodegroup_in_one_step',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
- catchup=False,
- max_active_runs=1,
tags=['example'],
+ catchup=False,
) as dag:
# [START howto_operator_eks_create_cluster_with_nodegroup]
# Create an Amazon EKS cluster control plane and an EKS nodegroup compute platform in one step.
create_cluster_and_nodegroup = EksCreateClusterOperator(
task_id='create_eks_cluster_and_nodegroup',
+ cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
cluster_role_arn=ROLE_ARN,
nodegroup_role_arn=ROLE_ARN,
@@ -68,12 +67,14 @@
await_create_nodegroup = EksNodegroupStateSensor(
task_id='wait_for_create_nodegroup',
+ cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
target_state=NodegroupStates.ACTIVE,
)
start_pod = EksPodOperator(
task_id="run_pod",
+ cluster_name=CLUSTER_NAME,
pod_name="run_pod",
image="amazon/aws-cli:latest",
cmds=["sh", "-c", "echo Test Airflow; date"],
@@ -86,11 +87,16 @@
# [START howto_operator_eks_force_delete_cluster]
# An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles.
# Setting the `force` to `True` will delete any attached resources before deleting the cluster.
- delete_all = EksDeleteClusterOperator(task_id='delete_nodegroup_and_cluster', force_delete_compute=True)
+ delete_all = EksDeleteClusterOperator(
+ task_id='delete_nodegroup_and_cluster',
+ cluster_name=CLUSTER_NAME,
+ force_delete_compute=True,
+ )
# [END howto_operator_eks_force_delete_cluster]
await_delete_cluster = EksClusterStateSensor(
task_id='wait_for_delete_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.NONEXISTENT,
)
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
index 3ec6a3ac459a0..efeeb14e04bfb 100644
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
+++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
@@ -45,48 +45,55 @@
with DAG(
- dag_id='example_eks_with_nodegroups_dag',
- default_args={'cluster_name': CLUSTER_NAME},
+ dag_id='example_eks_with_nodegroups',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
- catchup=False,
- max_active_runs=1,
tags=['example'],
+ catchup=False,
) as dag:
# [START howto_operator_eks_create_cluster]
# Create an Amazon EKS Cluster control plane without attaching compute service.
create_cluster = EksCreateClusterOperator(
task_id='create_eks_cluster',
+ cluster_name=CLUSTER_NAME,
cluster_role_arn=ROLE_ARN,
resources_vpc_config=VPC_CONFIG,
compute=None,
)
# [END howto_operator_eks_create_cluster]
+ # [START howto_sensor_eks_cluster]
await_create_cluster = EksClusterStateSensor(
task_id='wait_for_create_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.ACTIVE,
)
+ # [END howto_sensor_eks_cluster]
# [START howto_operator_eks_create_nodegroup]
create_nodegroup = EksCreateNodegroupOperator(
task_id='create_eks_nodegroup',
+ cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
nodegroup_subnets=SUBNETS,
nodegroup_role_arn=ROLE_ARN,
)
# [END howto_operator_eks_create_nodegroup]
+ # [START howto_sensor_eks_nodegroup]
await_create_nodegroup = EksNodegroupStateSensor(
task_id='wait_for_create_nodegroup',
+ cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
target_state=NodegroupStates.ACTIVE,
)
+ # [END howto_sensor_eks_nodegroup]
# [START howto_operator_eks_pod_operator]
start_pod = EksPodOperator(
task_id="run_pod",
+ cluster_name=CLUSTER_NAME,
pod_name="run_pod",
image="amazon/aws-cli:latest",
cmds=["sh", "-c", "ls"],
@@ -99,22 +106,29 @@
# [START howto_operator_eks_delete_nodegroup]
delete_nodegroup = EksDeleteNodegroupOperator(
- task_id='delete_eks_nodegroup', nodegroup_name=NODEGROUP_NAME
+ task_id='delete_eks_nodegroup',
+ cluster_name=CLUSTER_NAME,
+ nodegroup_name=NODEGROUP_NAME,
)
# [END howto_operator_eks_delete_nodegroup]
await_delete_nodegroup = EksNodegroupStateSensor(
task_id='wait_for_delete_nodegroup',
+ cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
target_state=NodegroupStates.NONEXISTENT,
)
# [START howto_operator_eks_delete_cluster]
- delete_cluster = EksDeleteClusterOperator(task_id='delete_eks_cluster')
+ delete_cluster = EksDeleteClusterOperator(
+ task_id='delete_eks_cluster',
+ cluster_name=CLUSTER_NAME,
+ )
# [END howto_operator_eks_delete_cluster]
await_delete_cluster = EksClusterStateSensor(
task_id='wait_for_delete_cluster',
+ cluster_name=CLUSTER_NAME,
target_state=ClusterStates.NONEXISTENT,
)
diff --git a/airflow/providers/amazon/aws/example_dags/example_sns.py b/airflow/providers/amazon/aws/example_dags/example_sns.py
new file mode 100644
index 0000000000000..782156b14c3d3
--- /dev/null
+++ b/airflow/providers/amazon/aws/example_dags/example_sns.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import datetime
+from os import environ
+
+from airflow import DAG
+from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator
+
+SNS_TOPIC_ARN = environ.get('SNS_TOPIC_ARN', 'arn:aws:sns:us-west-2:123456789012:dummy-topic-name')
+
+with DAG(
+ dag_id='example_sns',
+ schedule_interval=None,
+ start_date=datetime(2021, 1, 1),
+ tags=['example'],
+ catchup=False,
+) as dag:
+
+ # [START howto_operator_sns_publish_operator]
+ publish = SnsPublishOperator(
+ task_id='publish_message',
+ target_arn=SNS_TOPIC_ARN,
+ message='This is a sample message sent to SNS via an Apache Airflow DAG task.',
+ )
+ # [END howto_operator_sns_publish_operator]
diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py
index 48a436b020745..e916798d03386 100644
--- a/airflow/providers/amazon/aws/operators/sns.py
+++ b/airflow/providers/amazon/aws/operators/sns.py
@@ -30,6 +30,10 @@ class SnsPublishOperator(BaseOperator):
"""
Publish a message to Amazon SNS.
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:SnsPublishOperator`
+
:param aws_conn_id: aws connection to use
:param target_arn: either a TopicArn or an EndpointArn
:param message: the default message you want to send (templated)
diff --git a/airflow/providers/amazon/aws/sensors/eks.py b/airflow/providers/amazon/aws/sensors/eks.py
index 7f639b684103f..92ed55da4d31e 100644
--- a/airflow/providers/amazon/aws/sensors/eks.py
+++ b/airflow/providers/amazon/aws/sensors/eks.py
@@ -60,6 +60,10 @@ class EksClusterStateSensor(BaseSensorOperator):
"""
Check the state of an Amazon EKS Cluster until it reaches the target state or another terminal state.
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EksClusterStateSensor`
+
:param cluster_name: The name of the Cluster to watch. (templated)
:param target_state: Target state of the Cluster. (templated)
:param region: Which AWS region the connection should use. (templated)
@@ -116,6 +120,10 @@ class EksFargateProfileStateSensor(BaseSensorOperator):
"""
Check the state of an AWS Fargate profile until it reaches the target state or another terminal state.
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EksFargateProfileStateSensor`
+
:param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. (templated)
:param fargate_profile_name: The name of the Fargate profile to watch. (templated)
:param target_state: Target state of the Fargate profile. (templated)
@@ -183,6 +191,10 @@ class EksNodegroupStateSensor(BaseSensorOperator):
"""
Check the state of an EKS managed node group until it reaches the target state or another terminal state.
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EksNodegroupStateSensor`
+
:param cluster_name: The name of the Cluster which the Nodegroup is attached to. (templated)
:param nodegroup_name: The name of the Nodegroup to watch. (templated)
:param target_state: Target state of the Nodegroup. (templated)
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index f6887e0477e6c..962e48b10541c 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -127,6 +127,8 @@ integrations:
- integration-name: Amazon Simple Notification Service (SNS)
external-doc-url: https://aws.amazon.com/sns/
logo: /integration-logos/aws/Amazon-Simple-Notification-Service-SNS_light-bg@4x.png
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/sns.rst
tags: [aws]
- integration-name: Amazon Simple Queue Service (SQS)
external-doc-url: https://aws.amazon.com/sqs/
diff --git a/airflow/providers/apache/pig/example_dags/example_pig.py b/airflow/providers/apache/pig/example_dags/example_pig.py
index e3ab899e2f5be..ed1b34ab0c8a4 100644
--- a/airflow/providers/apache/pig/example_dags/example_pig.py
+++ b/airflow/providers/apache/pig/example_dags/example_pig.py
@@ -30,9 +30,11 @@
tags=['example'],
)
+# [START create_pig]
run_this = PigOperator(
task_id="run_example_pig_script",
pig="ls /;",
pig_opts="-x local",
dag=dag,
)
+# [END create_pig]
diff --git a/airflow/providers/apache/pig/provider.yaml b/airflow/providers/apache/pig/provider.yaml
index ae1cf0870dc5d..434e307b96ff1 100644
--- a/airflow/providers/apache/pig/provider.yaml
+++ b/airflow/providers/apache/pig/provider.yaml
@@ -19,7 +19,7 @@
package-name: apache-airflow-providers-apache-pig
name: Apache Pig
description: |
- `Apache Pig `__
+ `Apache Pig `__
versions:
- 2.0.1
@@ -33,6 +33,8 @@ additional-dependencies:
integrations:
- integration-name: Apache Pig
external-doc-url: https://pig.apache.org/
+ how-to-guide:
+ - /docs/apache-airflow-providers-apache-pig/operators.rst
logo: /integration-logos/apache/pig.png
tags: [apache]
@@ -46,7 +48,8 @@ hooks:
python-modules:
- airflow.providers.apache.pig.hooks.pig
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
+hook-class-names:
+ # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- airflow.providers.apache.pig.hooks.pig.PigCliHook
connection-types:
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index 7e6498cefd5ca..499b331da19c5 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -90,6 +90,8 @@ class GCSToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
+ :param source_object_required: Whether to raise an exception when the source object
+ doesn't exist. It has no effect when the source objects are folders or patterns.
:Example:
@@ -190,6 +192,7 @@ def __init__(
maximum_modified_time=None,
is_older_than=None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ source_object_required=False,
**kwargs,
):
super().__init__(**kwargs)
@@ -216,6 +219,7 @@ def __init__(
self.maximum_modified_time = maximum_modified_time
self.is_older_than = is_older_than
self.impersonation_chain = impersonation_chain
+ self.source_object_required = source_object_required
def execute(self, context: 'Context'):
@@ -313,6 +317,11 @@ def _copy_source_without_wildcard(self, hook, prefix):
self._copy_single_object(
hook=hook, source_object=prefix, destination_object=self.destination_object
)
+ elif self.source_object_required:
+ msg = f"{prefix} does not exist in bucket {self.source_bucket}"
+ self.log.warning(msg)
+ raise AirflowException(msg)
+
for source_obj in objects:
if self.destination_object is None:
destination_object = source_obj
diff --git a/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py b/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py
index 8332ba86d226b..084ebf4f2860d 100644
--- a/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py
+++ b/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py
@@ -19,11 +19,12 @@
from typing import Dict, List
from airflow import DAG
-from airflow.operators.python import PythonOperator
+from airflow.decorators import task
from airflow.providers.zendesk.hooks.zendesk import ZendeskHook
-def zendesk_custom_get_request() -> List[Dict]:
+@task
+def fetch_organizations() -> List[Dict]:
hook = ZendeskHook()
response = hook.get(
url="https://yourdomain.zendesk.com/api/v2/organizations.json",
@@ -37,7 +38,4 @@ def zendesk_custom_get_request() -> List[Dict]:
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag:
- fetch_organizations = PythonOperator(
- task_id="trigger_zendesk_hook",
- python_callable=zendesk_custom_get_request,
- )
+ fetch_organizations()
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d6abda7c74899..017f2276964ca 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -16,6 +16,7 @@
# under the License.
"""Serialized DAG and BaseOperator"""
+import contextlib
import datetime
import enum
import logging
@@ -168,7 +169,7 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable:
return timetable_class.deserialize(var[Encoding.VAR])
-class _XcomRef(NamedTuple):
+class _XComRef(NamedTuple):
"""
Used to store info needed to create XComArg when deserializing MappedOperator.
@@ -497,8 +498,8 @@ def _serialize_xcomarg(cls, arg: XComArg) -> dict:
return {"key": arg.key, "task_id": arg.operator.task_id}
@classmethod
- def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef:
- return _XcomRef(key=encoded['key'], task_id=encoded['task_id'])
+ def _deserialize_xcomref(cls, encoded: dict) -> _XComRef:
+ return _XComRef(key=encoded['key'], task_id=encoded['task_id'])
class DependencyDetector:
@@ -566,9 +567,19 @@ def task_type(self, task_type: str):
@classmethod
def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
-
stock_deps = op.deps is MappedOperator.DEFAULT_DEPS
serialize_op = cls._serialize_node(op, include_deps=not stock_deps)
+
+ # Simplify op_kwargs format. It must be a dict, so we flatten it.
+ with contextlib.suppress(KeyError):
+ op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"]
+ assert op_kwargs[Encoding.TYPE] == DAT.DICT
+ serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
+ with contextlib.suppress(KeyError):
+ op_kwargs = serialize_op["partial_kwargs"]["op_kwargs"]
+ assert op_kwargs[Encoding.TYPE] == DAT.DICT
+ serialize_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
+
# It must be a class at this point for it to work, not a string
assert isinstance(op.operator_class, type)
serialize_op['_task_type'] = op.operator_class.__name__
@@ -715,7 +726,13 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator,
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k in ("mapped_kwargs", "partial_kwargs"):
+ if "op_kwargs" not in v:
+ op_kwargs: Optional[dict] = None
+ else:
+ op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
v = {arg: cls._deserialize(value) for arg, value in v.items()}
+ if op_kwargs is not None:
+ v["op_kwargs"] = op_kwargs
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
# else use v as it is
@@ -1002,7 +1019,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
if isinstance(task, MappedOperator):
for d in (task.mapped_kwargs, task.partial_kwargs):
for k, v in d.items():
- if not isinstance(v, _XcomRef):
+ if not isinstance(v, _XComRef):
continue
d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)
diff --git a/airflow/task/task_runner/cgroup_task_runner.py b/airflow/task/task_runner/cgroup_task_runner.py
index 15075de64c567..d6c6e53abf935 100644
--- a/airflow/task/task_runner/cgroup_task_runner.py
+++ b/airflow/task/task_runner/cgroup_task_runner.py
@@ -146,11 +146,11 @@ def start(self):
self._mem_mb_limit = resources.ram.qty
# Create the memory cgroup
- mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name)
+ self.mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name)
self._created_mem_cgroup = True
if self._mem_mb_limit > 0:
self.log.debug("Setting %s with %s MB of memory", self.mem_cgroup_name, self._mem_mb_limit)
- mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024
+ self.mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024
# Create the CPU cgroup
cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name)
@@ -185,11 +185,32 @@ def terminate(self):
if self.process and psutil.pid_exists(self.process.pid):
reap_process_group(self.process.pid, self.log)
+ def _log_memory_usage(self, mem_cgroup_node):
+ def byte_to_gb(num_bytes, precision=2):
+ return round(num_bytes / (1024 * 1024 * 1024), precision)
+
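+ # memory.max_usage_in_bytes records the peak memory usage observed by the cgroup v1 memory controller.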
+ with open(mem_cgroup_node.full_path + '/memory.max_usage_in_bytes') as f:
+ max_usage_in_bytes = int(f.read().strip())
+
+ used_gb = byte_to_gb(max_usage_in_bytes)
+ limit_gb = byte_to_gb(mem_cgroup_node.controller.limit_in_bytes)
+
+ self.log.info(
+ "Memory max usage of the task is %s GB, while the memory limit is %s GB", used_gb, limit_gb
+ )
+
+ if max_usage_in_bytes >= mem_cgroup_node.controller.limit_in_bytes:
+ self.log.info(
+ "This task has reached the memory limit allocated by Airflow worker. "
+ "If it failed, try to optimize the task or reserve more memory."
+ )
+
def on_finish(self):
# Let the OOM watcher thread know we're done to avoid false OOM alarms
self._finished_running = True
# Clean up the cgroups
if self._created_mem_cgroup:
+ self._log_memory_usage(self.mem_cgroup_node)
self._delete_cgroup(self.mem_cgroup_name)
if self._created_cpu_cgroup:
self._delete_cgroup(self.cpu_cgroup_name)
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 5bdbf0c14f8dd..a85ab8971cd1c 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -35,7 +35,6 @@
Tuple,
TypeVar,
)
-from urllib import parse
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -258,8 +257,7 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
import flask
view = conf.get('webserver', 'dag_default_view').lower()
- url = flask.url_for(f"Airflow.{view}")
- return f"{url}?{parse.urlencode(query)}"
+ return flask.url_for(f"Airflow.{view}", **query)
# The 'template' argument is typed as Any because the jinja2.Template is too
diff --git a/airflow/utils/operator_resources.py b/airflow/utils/operator_resources.py
index 8c3263247faa9..28710cddd42cc 100644
--- a/airflow/utils/operator_resources.py
+++ b/airflow/utils/operator_resources.py
@@ -50,6 +50,8 @@ def __init__(self, name, units_str, qty):
self._qty = qty
def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
return self.__dict__ == other.__dict__
def __repr__(self):
@@ -126,6 +128,8 @@ def __init__(
self.gpus = GpuResource(gpus)
def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
return self.__dict__ == other.__dict__
def __repr__(self):
diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py
index 080fe682991c2..47a9847c2eecf 100644
--- a/airflow/www/decorators.py
+++ b/airflow/www/decorators.py
@@ -20,6 +20,7 @@
import gzip
import logging
from io import BytesIO as IO
+from itertools import chain
from typing import Callable, TypeVar, cast
import pendulum
@@ -48,13 +49,19 @@ def wrapper(*args, **kwargs):
user = g.user.username
fields_skip_logging = {'csrf_token', '_csrf_token'}
+ log_fields = {
+ k: v
+ for k, v in chain(request.values.items(), request.view_args.items())
+ if k not in fields_skip_logging
+ }
+
log = Log(
event=f.__name__,
task_instance=None,
owner=user,
- extra=str([(k, v) for k, v in request.values.items() if k not in fields_skip_logging]),
- task_id=request.values.get('task_id'),
- dag_id=request.values.get('dag_id'),
+ extra=str([(k, log_fields[k]) for k in log_fields]),
+ task_id=log_fields.get('task_id'),
+ dag_id=log_fields.get('dag_id'),
)
if 'execution_date' in request.values:
diff --git a/airflow/www/fab_security/sqla/models.py b/airflow/www/fab_security/sqla/models.py
index 69853722d59b7..93a95a45c5e21 100644
--- a/airflow/www/fab_security/sqla/models.py
+++ b/airflow/www/fab_security/sqla/models.py
@@ -37,7 +37,6 @@
)
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, relationship
-from sqlalchemy.orm.relationships import foreign
"""
Compatibility note: The models in this file are duplicated from Flask AppBuilder.
@@ -140,17 +139,13 @@ class Permission(Model):
action_id = Column("permission_id", Integer, ForeignKey("ab_permission.id"))
action = relationship(
"Action",
- primaryjoin=action_id == foreign(Action.id),
uselist=False,
- backref="permission",
lazy="joined",
)
resource_id = Column("view_menu_id", Integer, ForeignKey("ab_view_menu.id"))
resource = relationship(
"Resource",
- primaryjoin=resource_id == foreign(Resource.id),
uselist=False,
- backref="permission",
lazy="joined",
)
diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html
index e57a1f1fc8b3b..0f4b967373da5 100644
--- a/airflow/www/templates/airflow/dag.html
+++ b/airflow/www/templates/airflow/dag.html
@@ -110,9 +110,9 @@
-
+
Zoom into Sub DAG
@@ -402,7 +402,7 @@
-
+
account_tree
Graph
diff --git a/airflow/www/templates/airflow/dags.html b/airflow/www/templates/airflow/dags.html
index c748d329a2193..41f63bf6b8ad1 100644
--- a/airflow/www/templates/airflow/dags.html
+++ b/airflow/www/templates/airflow/dags.html
@@ -31,7 +31,7 @@
-
+
@@ -293,9 +293,9 @@ {{ page_title }}
account_tree
Graph
-
- nature
- Tree
+
+ grid_on
+ Grid
more_horiz
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 18eb1e7f15dd8..d9ec10f8b8887 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -356,8 +356,10 @@ def dag_link(attr):
"""Generates a URL to the Graph view for a Dag."""
dag_id = attr.get('dag_id')
execution_date = attr.get('execution_date')
+ if not dag_id:
+ return Markup('None')
url = url_for('Airflow.graph', dag_id=dag_id, execution_date=execution_date)
- return Markup('<a href="{}">{}</a>').format(url, dag_id) if dag_id else Markup('None')
+ return Markup('<a href="{}">{}</a>').format(url, dag_id)
def dag_run_link(attr):
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 678d1cc4db3f1..7ff2c74c9b997 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1053,15 +1053,24 @@ def last_dagruns(self, session=None):
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE),
]
)
+ def legacy_code(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.code', **request.args))
+
+ @expose('/dags/<string:dag_id>/code')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE),
+ ]
+ )
@provide_session
- def code(self, session=None):
+ def code(self, dag_id, session=None):
"""Dag Code."""
all_errors = ""
dag_orm = None
- dag_id = None
try:
- dag_id = request.args.get('dag_id')
dag_orm = DagModel.get_dagmodel(dag_id, session=session)
code = DagCode.get_code_by_fileloc(dag_orm.fileloc)
html_code = Markup(highlight(code, lexers.PythonLexer(), HtmlFormatter(linenos=True)))
@@ -1094,10 +1103,20 @@ def code(self, session=None):
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
]
)
+ def legacy_dag_details(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.dag_details', **request.args))
+
+ @expose('/dags/<string:dag_id>/details')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ ]
+ )
@provide_session
- def dag_details(self, session=None):
+ def dag_details(self, dag_id, session=None):
"""Get Dag details."""
- dag_id = request.args.get('dag_id')
dag = current_app.dag_bag.get_dag(dag_id)
dag_model = DagModel.get_dagmodel(dag_id)
@@ -2300,6 +2319,34 @@ def success(self):
State.SUCCESS,
)
+ @expose('/dags/<string:dag_id>')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
+ ]
+ )
+ @gzipped
+ @action_logging
+ def dag(self, dag_id):
+ """Redirect to default DAG view."""
+ return redirect(url_for('Airflow.grid', dag_id=dag_id, **request.args))
+
+ @expose('/legacy_tree')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
+ ]
+ )
+ @gzipped
+ @action_logging
+ def legacy_tree(self):
+ """Redirect to the replacement - grid view."""
+ return redirect(url_for('Airflow.grid', **request.args))
+
@expose('/tree')
@auth.has_access(
[
@@ -2310,10 +2357,23 @@ def success(self):
)
@gzipped
@action_logging
+ def tree(self):
+ """Redirect to the replacement - grid view. Kept for backwards compatibility."""
+ return redirect(url_for('Airflow.grid', **request.args))
+
+ @expose('/dags/<string:dag_id>/grid')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
+ ]
+ )
+ @gzipped
+ @action_logging
@provide_session
- def tree(self, session=None):
- """Get Dag as tree."""
- dag_id = request.args.get('dag_id')
+ def grid(self, dag_id, session=None):
+ """Get Dag's grid view."""
dag = current_app.dag_bag.get_dag(dag_id)
dag_model = DagModel.get_dagmodel(dag_id)
if not dag:
@@ -2398,8 +2458,21 @@ def tree(self, session=None):
)
@gzipped
@action_logging
+ def legacy_calendar(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.calendar', **request.args))
+
+ @expose('/dags/<string:dag_id>/calendar')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
+ @gzipped
+ @action_logging
@provide_session
- def calendar(self, session=None):
+ def calendar(self, dag_id, session=None):
"""Get DAG runs as calendar"""
def _convert_to_date(session, column):
@@ -2409,7 +2482,6 @@ def _convert_to_date(session, column):
else:
return func.date(column)
- dag_id = request.args.get('dag_id')
dag = current_app.dag_bag.get_dag(dag_id)
dag_model = DagModel.get_dagmodel(dag_id)
if not dag:
@@ -2475,10 +2547,23 @@ def _convert_to_date(session, column):
)
@gzipped
@action_logging
+ def legacy_graph(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.graph', **request.args))
+
+ @expose('/dags/<string:dag_id>/graph')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
+ ]
+ )
+ @gzipped
+ @action_logging
@provide_session
- def graph(self, session=None):
+ def graph(self, dag_id, session=None):
"""Get DAG as Graph."""
- dag_id = request.args.get('dag_id')
dag = current_app.dag_bag.get_dag(dag_id)
dag_model = DagModel.get_dagmodel(dag_id)
if not dag:
@@ -2568,11 +2653,22 @@ class GraphForm(DateTimeWithNumRunsWithDagRunsForm):
]
)
@action_logging
+ def legacy_duration(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.duration', **request.args))
+
+ @expose('/dags/<string:dag_id>/duration')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
+ @action_logging
@provide_session
- def duration(self, session=None):
+ def duration(self, dag_id, session=None):
"""Get Dag as duration graph."""
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
- dag_id = request.args.get('dag_id')
dag_model = DagModel.get_dagmodel(dag_id)
dag: Optional[DAG] = current_app.dag_bag.get_dag(dag_id)
@@ -2710,11 +2806,22 @@ def duration(self, session=None):
]
)
@action_logging
+ def legacy_tries(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.tries', **request.args))
+
+ @expose('/dags/<string:dag_id>/tries')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
+ @action_logging
@provide_session
- def tries(self, session=None):
+ def tries(self, dag_id, session=None):
"""Shows all tries."""
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
- dag_id = request.args.get('dag_id')
dag = current_app.dag_bag.get_dag(dag_id)
dag_model = DagModel.get_dagmodel(dag_id)
base_date = request.args.get('base_date')
@@ -2787,11 +2894,22 @@ def tries(self, session=None):
]
)
@action_logging
+ def legacy_landing_times(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.landing_times', **request.args))
+
+ @expose('/dags/<string:dag_id>/landing-times')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
+ @action_logging
@provide_session
- def landing_times(self, session=None):
+ def landing_times(self, dag_id, session=None):
"""Shows landing times."""
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
- dag_id = request.args.get('dag_id')
dag: DAG = current_app.dag_bag.get_dag(dag_id)
dag_model = DagModel.get_dagmodel(dag_id)
base_date = request.args.get('base_date')
@@ -2892,10 +3010,21 @@ def paused(self):
]
)
@action_logging
+ def legacy_gantt(self):
+ """Redirect from url param."""
+ return redirect(url_for('Airflow.gantt', **request.args))
+
+    @expose('/dags/<string:dag_id>/gantt')
+ @auth.has_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ]
+ )
+ @action_logging
@provide_session
- def gantt(self, session=None):
+ def gantt(self, dag_id, session=None):
"""Show GANTT chart."""
- dag_id = request.args.get('dag_id')
dag = current_app.dag_bag.get_dag(dag_id)
dag_model = DagModel.get_dagmodel(dag_id)
diff --git a/docs/README.rst b/docs/README.rst
index cc951484e1a54..2beb98379abc9 100644
--- a/docs/README.rst
+++ b/docs/README.rst
@@ -68,13 +68,17 @@ You can also see all the available arguments via ``--help``.
Running the Docs Locally
------------------------
-Once you have built the documentation run the following command from the root directory:
+Once you have built the documentation, run the following command from the root directory.
+You need to have Python installed to run the command:
.. code-block:: bash
docs/start_doc_server.sh
-Then, view your docs at ``localhost:8000``
+
+Then, view your docs at ``localhost:8000``. If you are using a virtual machine, e.g. WSL2,
+you need to find the WSL2 machine's IP address and replace ``0.0.0.0`` in your browser with it,
+i.e. ``http://n.n.n.n:8000``, where n.n.n.n is the IP of the WSL2 machine.
Troubleshooting
---------------
diff --git a/docs/apache-airflow-providers-amazon/operators/eks.rst b/docs/apache-airflow-providers-amazon/operators/eks.rst
index e2b856de91dd7..d6cd5ad791973 100644
--- a/docs/apache-airflow-providers-amazon/operators/eks.rst
+++ b/docs/apache-airflow-providers-amazon/operators/eks.rst
@@ -34,6 +34,21 @@ Prerequisite Tasks
Manage Amazon EKS Clusters
^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. _howto/sensor:EksClusterStateSensor:
+
+Amazon EKS Cluster State Sensor
+"""""""""""""""""""""""""""""""
+
+To check the state of an Amazon EKS Cluster until it reaches the target state or another terminal
+state, you can use :class:`~airflow.providers.amazon.aws.sensors.eks.EksClusterStateSensor`.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_eks_cluster]
+ :end-before: [END howto_sensor_eks_cluster]
+
+
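For orientation, here is a minimal, self-contained sketch of how such a sensor task could be wired into a DAG; the DAG id, cluster name, and schedule below are illustrative placeholders, not values taken from the referenced example file.

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.hooks.eks import ClusterStates
    from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor

    # Placeholder DAG id and cluster name, used only for illustration.
    with DAG(dag_id="eks_cluster_sensor_sketch", start_date=datetime(2021, 1, 1), schedule_interval=None):
        # Poke until the cluster reaches ACTIVE, or fail early on a terminal state.
        EksClusterStateSensor(
            task_id="await_cluster",
            cluster_name="my-eks-cluster",
            target_state=ClusterStates.ACTIVE,
        )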
.. _howto/operator:EksCreateClusterOperator:
Create an Amazon EKS Cluster
@@ -48,6 +63,7 @@ Note: An AWS IAM role with the following permissions is required:
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_create_cluster]
:end-before: [END howto_operator_eks_create_cluster]
@@ -61,6 +77,7 @@ To delete an existing Amazon EKS Cluster you can use
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_delete_cluster]
:end-before: [END howto_operator_eks_delete_cluster]
@@ -70,6 +87,7 @@ Note: If the cluster has any attached resources, such as an Amazon EKS Nodegroup
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_force_delete_cluster]
:end-before: [END howto_operator_eks_force_delete_cluster]
@@ -77,6 +95,20 @@ Note: If the cluster has any attached resources, such as an Amazon EKS Nodegroup
Manage Amazon EKS Managed Nodegroups
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. _howto/sensor:EksNodegroupStateSensor:
+
+Amazon EKS Managed Nodegroup State Sensor
+"""""""""""""""""""""""""""""""""""""""""
+
+To check the state of an Amazon EKS managed node group until it reaches the target state or another terminal
+state, you can use :class:`~airflow.providers.amazon.aws.sensors.eks.EksNodegroupStateSensor`.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_eks_nodegroup]
+ :end-before: [END howto_sensor_eks_nodegroup]
+
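Along the same lines, a hedged sketch of the nodegroup sensor; the cluster and nodegroup names are placeholders.

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.hooks.eks import NodegroupStates
    from airflow.providers.amazon.aws.sensors.eks import EksNodegroupStateSensor

    with DAG(dag_id="eks_nodegroup_sensor_sketch", start_date=datetime(2021, 1, 1), schedule_interval=None):
        # Wait for the managed nodegroup to become ACTIVE.
        EksNodegroupStateSensor(
            task_id="await_nodegroup",
            cluster_name="my-eks-cluster",
            nodegroup_name="my-nodegroup",
            target_state=NodegroupStates.ACTIVE,
        )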
.. _howto/operator:EksCreateNodegroupOperator:
Create an Amazon EKS Managed NodeGroup
@@ -92,6 +124,7 @@ Note: An AWS IAM role with the following permissions is required:
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_create_nodegroup]
:end-before: [END howto_operator_eks_create_nodegroup]
@@ -105,6 +138,7 @@ To delete an existing Amazon EKS Managed Nodegroup you can use
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_delete_nodegroup]
:end-before: [END howto_operator_eks_delete_nodegroup]
@@ -124,6 +158,7 @@ Note: An AWS IAM role with the following permissions is required:
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_create_cluster_with_nodegroup]
:end-before: [END howto_operator_eks_create_cluster_with_nodegroup]
@@ -142,12 +177,28 @@ Note: An AWS IAM role with the following permissions is required:
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_create_cluster_with_fargate_profile]
:end-before: [END howto_operator_eks_create_cluster_with_fargate_profile]
Manage AWS Fargate Profiles
^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. _howto/sensor:EksFargateProfileStateSensor:
+
+AWS Fargate Profile State Sensor
+""""""""""""""""""""""""""""""""
+
+To check the state of an AWS Fargate profile until it reaches the target state or another terminal
+state, you can use :class:`~airflow.providers.amazon.aws.sensors.eks.EksFargateProfileStateSensor`.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_eks_fargate]
+ :end-before: [END howto_sensor_eks_fargate]
+
+
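As with the other state sensors, a minimal sketch is shown below; the cluster and profile names are placeholders.

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.hooks.eks import FargateProfileStates
    from airflow.providers.amazon.aws.sensors.eks import EksFargateProfileStateSensor

    with DAG(dag_id="eks_fargate_sensor_sketch", start_date=datetime(2021, 1, 1), schedule_interval=None):
        # Wait for the Fargate profile to become ACTIVE.
        EksFargateProfileStateSensor(
            task_id="await_fargate_profile",
            cluster_name="my-eks-cluster",
            fargate_profile_name="my-fargate-profile",
            target_state=FargateProfileStates.ACTIVE,
        )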
.. _howto/operator:EksCreateFargateProfileOperator:
Create an AWS Fargate Profile
@@ -163,6 +214,7 @@ Note: An AWS IAM role with the following permissions is required:
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_create_fargate_profile]
:end-before: [END howto_operator_eks_create_fargate_profile]
@@ -176,6 +228,7 @@ To delete an existing AWS Fargate Profile you can use
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_delete_fargate_profile]
:end-before: [END howto_operator_eks_delete_fargate_profile]
@@ -191,6 +244,7 @@ Note: An Amazon EKS Cluster with underlying compute infrastructure is required.
.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
:language: python
+ :dedent: 4
:start-after: [START howto_operator_eks_pod_operator]
:end-before: [END howto_operator_eks_pod_operator]
diff --git a/docs/apache-airflow-providers-amazon/operators/sns.rst b/docs/apache-airflow-providers-amazon/operators/sns.rst
new file mode 100644
index 0000000000000..1853bf27ae9dc
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/sns.rst
@@ -0,0 +1,59 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Amazon Simple Notification Service (SNS) Operators
+==================================================
+
+`Amazon Simple Notification Service (Amazon SNS) <https://aws.amazon.com/sns/>`__ is a managed
+service that provides message delivery from publishers to subscribers (also known as producers
+and consumers). Publishers communicate asynchronously with subscribers by sending messages to
+a topic, which is a logical access point and communication channel. Clients can subscribe to the
+SNS topic and receive published messages using a supported endpoint type, such as Amazon Kinesis
+Data Firehose, Amazon SQS, AWS Lambda, HTTP, email, mobile push notifications, and mobile text
+messages (SMS).
+
+Airflow provides an operator to publish messages to an SNS Topic.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+
+.. _howto/operator:SnsPublishOperator:
+
+Publish A Message To An Existing SNS Topic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To publish a message to an Amazon SNS Topic you can use
+:class:`~airflow.providers.amazon.aws.operators.sns.SnsPublishOperator`.
+
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_sns.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sns_publish_operator]
+ :end-before: [END howto_operator_sns_publish_operator]
+
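For readers who do not have the example DAG at hand, here is a minimal sketch of publishing with this operator; the topic ARN, message, and DAG id are placeholders, not values from the example file.

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator

    with DAG(dag_id="sns_publish_sketch", start_date=datetime(2021, 1, 1), schedule_interval=None):
        # Publish a plain-text message to an existing topic (placeholder ARN).
        SnsPublishOperator(
            task_id="publish_message",
            target_arn="arn:aws:sns:us-east-1:123456789012:my-topic",
            message="Hello from Airflow",
            subject="Airflow notification",
        )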
+
+Reference
+^^^^^^^^^
+
+For further information, look at:
+
+* `Boto3 Library Documentation for SNS <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sns.html>`__
diff --git a/docs/apache-airflow-providers-apache-pig/index.rst b/docs/apache-airflow-providers-apache-pig/index.rst
index 490cfea2405da..4bd2b49184ce8 100644
--- a/docs/apache-airflow-providers-apache-pig/index.rst
+++ b/docs/apache-airflow-providers-apache-pig/index.rst
@@ -21,6 +21,12 @@
Content
-------
+.. toctree::
+ :maxdepth: 1
+ :caption: Guides
+
+ Operators
+
.. toctree::
:maxdepth: 1
:caption: References
diff --git a/docs/apache-airflow-providers-apache-pig/operators.rst b/docs/apache-airflow-providers-apache-pig/operators.rst
new file mode 100644
index 0000000000000..04e29e175b1e2
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-pig/operators.rst
@@ -0,0 +1,32 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+Apache Pig Operators
+====================
+
+Apache Pig is a platform for analyzing large data sets that consists of a high-level language
+for expressing data analysis programs, coupled with infrastructure for evaluating these programs.
+Pig programs are amenable to substantial parallelization, which in turn enables them to handle very large data sets.
+
+Use the PigOperator to execute a Pig script:
+
+.. exampleinclude:: /../../airflow/providers/apache/pig/example_dags/example_pig.py
+ :language: python
+ :start-after: [START create_pig]
+ :end-before: [END create_pig]
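A minimal, self-contained sketch of the operator in a DAG follows; the one-line Pig script and the local-mode flag are illustrative choices, not requirements.

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.apache.pig.operators.pig import PigOperator

    with DAG(dag_id="pig_operator_sketch", start_date=datetime(2021, 1, 1), schedule_interval=None):
        # Run a trivial Pig Latin statement; '-x local' executes Pig in local mode.
        PigOperator(
            task_id="run_example_pig_script",
            pig="ls /;",
            pig_opts="-x local",
        )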
diff --git a/docs/apache-airflow-providers-postgres/connections/postgres.rst b/docs/apache-airflow-providers-postgres/connections/postgres.rst
index 542511baaa987..3ffbae7d2771c 100644
--- a/docs/apache-airflow-providers-postgres/connections/postgres.rst
+++ b/docs/apache-airflow-providers-postgres/connections/postgres.rst
@@ -58,7 +58,7 @@ Extra (optional)
* ``keepalives_idle`` - Controls the number of seconds of inactivity after which TCP
should send a keepalive message to the server.
* ``client_encoding``: specifies client encoding(character set) of the client connection.
- Refer to `Postgres supported character sets `_
+ Refer to `Postgres supported character sets `_
More details on all Postgres parameters supported can be found in
`Postgres documentation `_.
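To make the extras above concrete, here is a hedged sketch of a connection object carrying them; the host, credentials, and connection id are placeholders.

.. code-block:: python

    import json

    from airflow.models.connection import Connection

    # Placeholder connection details; only the 'extra' field demonstrates the options above.
    conn = Connection(
        conn_id="my_postgres",
        conn_type="postgres",
        host="db.example.com",
        login="airflow",
        password="secret",
        schema="mydb",
        extra=json.dumps({"client_encoding": "utf-8", "keepalives_idle": 60}),
    )
    # Render the same settings as a URI, e.g. for an AIRFLOW_CONN_MY_POSTGRES environment variable.
    print(conn.get_uri())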
diff --git a/docs/apache-airflow/tutorial.rst b/docs/apache-airflow/tutorial.rst
index 02369de99294f..bd497970deb6f 100644
--- a/docs/apache-airflow/tutorial.rst
+++ b/docs/apache-airflow/tutorial.rst
@@ -413,7 +413,7 @@ Let's break this down into 2 steps: get data & merge data:
import requests
from airflow.decorators import task
- from airflow.hooks.postgres import PostgresHook
+ from airflow.providers.postgres.hooks.postgres import PostgresHook
@task
@@ -478,7 +478,7 @@ Lets look at our DAG:
import requests
from airflow.decorators import dag, task
- from airflow.hooks.postgres import PostgresHook
+ from airflow.providers.postgres.hooks.postgres import PostgresHook
@dag(
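The tutorial change above only moves the import to the provider package; a minimal sketch of the corrected import in use follows (the connection id is a placeholder, not necessarily the one used in the tutorial).

.. code-block:: python

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    # Placeholder connection id; the hook resolves it from the Airflow connection store.
    hook = PostgresHook(postgres_conn_id="tutorial_pg_conn")
    rows = hook.get_records("SELECT 1")
    print(rows)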
diff --git a/docs/start_doc_server.sh b/docs/start_doc_server.sh
index 28397c59ff524..ce67eb8d10fb2 100755
--- a/docs/start_doc_server.sh
+++ b/docs/start_doc_server.sh
@@ -20,5 +20,5 @@ DOCS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
readonly DOCS_DIR
(cd "${DOCS_DIR}"/_build || exit;
- python -m http.server 8000
+ python3 -m http.server 8000
)
diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py
index 5ea21a216b176..2746e5963806f 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -45,17 +45,13 @@
DagParsingStat,
)
from airflow.dag_processing.processor import DagFileProcessorProcess
-from airflow.jobs.local_task_job import LocalTaskJob as LJ
-from airflow.models import DagBag, DagModel, TaskInstance as TI, errors
+from airflow.models import DagBag, DagModel, errors
from airflow.models.dagcode import DagCode
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import SimpleTaskInstance
from airflow.utils import timezone
-from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest
+from airflow.utils.callback_requests import CallbackRequest
from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
-from airflow.utils.state import DagRunState, State
-from airflow.utils.types import DagRunType
from tests.core.test_logging_config import SETTINGS_FILE_VALID, settings_context
from tests.models import TEST_DAGS_FOLDER
from tests.test_utils.config import conf_vars
@@ -455,147 +451,6 @@ def test_recently_modified_file_is_parsed_with_mtime_mode(
> (freezed_base_time - manager.get_last_finish_time("file_1.py")).total_seconds()
)
- def test_find_zombies(self):
- manager = DagFileProcessorManager(
- dag_directory='directory',
- max_runs=1,
- processor_timeout=timedelta.max,
- signal_conn=MagicMock(),
- dag_ids=[],
- pickle_dags=False,
- async_mode=True,
- )
-
- dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
- with create_session() as session:
- session.query(LJ).delete()
- dag = dagbag.get_dag('example_branch_operator')
- dag.sync_to_db()
- task = dag.get_task(task_id='run_this_first')
-
- dag_run = dag.create_dagrun(
- state=DagRunState.RUNNING,
- execution_date=DEFAULT_DATE,
- run_type=DagRunType.SCHEDULED,
- session=session,
- )
-
- ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING)
- local_job = LJ(ti)
- local_job.state = State.SHUTDOWN
-
- session.add(local_job)
- session.flush()
-
- ti.job_id = local_job.id
- session.add(ti)
- session.flush()
-
- manager._last_zombie_query_time = timezone.utcnow() - timedelta(
- seconds=manager._zombie_threshold_secs + 1
- )
- manager._find_zombies()
- requests = manager._callback_to_execute[dag.fileloc]
- assert 1 == len(requests)
- assert requests[0].full_filepath == dag.fileloc
- assert requests[0].msg == f"Detected {ti} as zombie"
- assert requests[0].is_failure_callback is True
- assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance)
- assert ti.dag_id == requests[0].simple_task_instance.dag_id
- assert ti.task_id == requests[0].simple_task_instance.task_id
- assert ti.run_id == requests[0].simple_task_instance.run_id
-
- session.query(TI).delete()
- session.query(LJ).delete()
-
- @mock.patch('airflow.dag_processing.manager.DagFileProcessorProcess')
- def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor(
- self, mock_processor
- ):
- """
- Check that the same set of failure callback with zombies are passed to the dag
- file processors until the next zombie detection logic is invoked.
- """
- test_dag_path = TEST_DAG_FOLDER / 'test_example_bash_operator.py'
- with conf_vars({('scheduler', 'parsing_processes'): '1', ('core', 'load_examples'): 'False'}):
- dagbag = DagBag(test_dag_path, read_dags_from_db=False)
- with create_session() as session:
- session.query(LJ).delete()
- dag = dagbag.get_dag('test_example_bash_operator')
- dag.sync_to_db()
-
- dag_run = dag.create_dagrun(
- state=DagRunState.RUNNING,
- execution_date=DEFAULT_DATE,
- run_type=DagRunType.SCHEDULED,
- session=session,
- )
- task = dag.get_task(task_id='run_this_last')
-
- ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING)
- local_job = LJ(ti)
- local_job.state = State.SHUTDOWN
- session.add(local_job)
- session.flush()
-
- # TODO: If there was an actual Relationship between TI and Job
- # we wouldn't need this extra commit
- session.add(ti)
- ti.job_id = local_job.id
- session.flush()
-
- expected_failure_callback_requests = [
- TaskCallbackRequest(
- full_filepath=dag.fileloc,
- simple_task_instance=SimpleTaskInstance(ti),
- msg="Message",
- )
- ]
-
- test_dag_path = TEST_DAG_FOLDER / 'test_example_bash_operator.py'
-
- child_pipe, parent_pipe = multiprocessing.Pipe()
- async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
-
- fake_processors = []
-
- def fake_processor_(*args, **kwargs):
- nonlocal fake_processors
- processor = FakeDagFileProcessorRunner._create_process(*args, **kwargs)
- fake_processors.append(processor)
- return processor
-
- mock_processor.side_effect = fake_processor_
-
- manager = DagFileProcessorManager(
- dag_directory=test_dag_path,
- max_runs=1,
- processor_timeout=timedelta.max,
- signal_conn=child_pipe,
- dag_ids=[],
- pickle_dags=False,
- async_mode=async_mode,
- )
-
- self.run_processor_manager_one_loop(manager, parent_pipe)
-
- if async_mode:
- # Once for initial parse, and then again for the add_callback_to_queue
- assert len(fake_processors) == 2
- assert fake_processors[0]._file_path == str(test_dag_path)
- assert fake_processors[0]._callback_requests == []
- else:
- assert len(fake_processors) == 1
-
- assert fake_processors[-1]._file_path == str(test_dag_path)
- callback_requests = fake_processors[-1]._callback_requests
- assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == {
- result.simple_task_instance.key for result in callback_requests
- }
-
- child_pipe.close()
- parent_pipe.close()
-
@mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock)
@mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.kill")
def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid):
diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py
new file mode 100644
index 0000000000000..f21a9a5e8a42d
--- /dev/null
+++ b/tests/dags/test_mapped_taskflow.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import DAG
+from airflow.utils.dates import days_ago
+
+with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag:
+
+ @dag.task
+ def make_list():
+ return [1, 2, {'a': 'b'}]
+
+ @dag.task
+ def consumer(value):
+ print(repr(value))
+
+ consumer.map(value=make_list())
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 0c93b49e1fe00..ee94fde610d7a 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -17,7 +17,7 @@
# under the License.
import sys
from collections import namedtuple
-from datetime import date, timedelta
+from datetime import date, datetime, timedelta
from typing import Dict # noqa: F401 # This is used by annotation tests.
from typing import Tuple
@@ -490,7 +490,7 @@ def double(number: int):
assert isinstance(doubled_0, XComArg)
assert isinstance(doubled_0.operator, MappedOperator)
assert doubled_0.operator.task_id == "double"
- assert doubled_0.operator.mapped_kwargs == {"number": literal}
+ assert doubled_0.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}}
assert doubled_1.operator.task_id == "double__1"
@@ -514,25 +514,68 @@ def test_partial_mapped_decorator() -> None:
def product(number: int, multiple: int):
return number * multiple
+ literal = [1, 2, 3]
+
with DAG('test_dag', start_date=DEFAULT_DATE) as dag:
- literal = [1, 2, 3]
- quadrupled = product.partial(task_id='times_4', multiple=3).map(number=literal)
+ quadrupled = product.partial(multiple=3).map(number=literal)
doubled = product.partial(multiple=2).map(number=literal)
trippled = product.partial(multiple=3).map(number=literal)
- product.partial(multiple=2)
+ product.partial(multiple=2) # No operator is actually created.
+
+ assert dag.task_dict == {
+ "product": quadrupled.operator,
+ "product__1": doubled.operator,
+ "product__2": trippled.operator,
+ }
assert isinstance(doubled, XComArg)
assert isinstance(doubled.operator, MappedOperator)
- assert doubled.operator.task_id == "product"
- assert doubled.operator.mapped_kwargs == {"number": literal}
- assert doubled.operator.partial_kwargs == {"task_id": "product", "multiple": 2}
+ assert doubled.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}}
+ assert doubled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 2}}
- assert trippled.operator.task_id == "product__1"
- assert trippled.operator.partial_kwargs == {"task_id": "product", "multiple": 3}
-
- assert quadrupled.operator.task_id == "times_4"
+ assert isinstance(trippled.operator, MappedOperator) # For type-checking on partial_kwargs.
+ assert trippled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 3}}
assert doubled.operator is not trippled.operator
- assert [quadrupled.operator, doubled.operator, trippled.operator] == dag.tasks
+
+def test_mapped_decorator_unmap_merge_op_kwargs():
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+
+ @task_decorator
+ def task1():
+ ...
+
+ @task_decorator
+ def task2(arg1, arg2):
+ ...
+
+ task2.partial(arg1=1).map(arg2=task1())
+
+ unmapped = dag.get_task("task2").unmap()
+ assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
+
+
+def test_mapped_decorator_unmap_converts_partial_kwargs():
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+
+ @task_decorator
+ def task1(arg):
+ ...
+
+ @task_decorator(retry_delay=30)
+ def task2(arg1, arg2):
+ ...
+
+ task2.partial(arg1=1).map(arg2=task1.map(arg=[1, 2]))
+
+ # Arguments to the task decorator are stored in partial_kwargs, and
+ # converted into their intended form after the task is unmapped.
+ mapped_task2 = dag.get_task("task2")
+ assert mapped_task2.partial_kwargs["retry_delay"] == 30
+ assert mapped_task2.unmap().retry_delay == timedelta(seconds=30)
+
+ mapped_task1 = dag.get_task("task1")
+ assert "retry_delay" not in mapped_task1.partial_kwargs
+    assert mapped_task1.unmap().retry_delay == timedelta(seconds=300)  # Operator default.
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 0878f63ddffc6..40593d526a328 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -47,7 +47,13 @@
from airflow.utils.timeout import timeout
from airflow.utils.types import DagRunType
from tests.models import TEST_DAGS_FOLDER
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
+from tests.test_utils.db import (
+ clear_db_dags,
+ clear_db_pools,
+ clear_db_runs,
+ clear_db_xcom,
+ set_default_pool_slots,
+)
from tests.test_utils.mock_executor import MockExecutor
from tests.test_utils.timetables import cron_timetable
@@ -66,6 +72,7 @@ class TestBackfillJob:
def clean_db():
clear_db_dags()
clear_db_runs()
+ clear_db_xcom()
clear_db_pools()
@pytest.fixture(autouse=True)
@@ -1512,13 +1519,14 @@ def test_backfill_has_job_id(self):
job.run()
assert executor.job_id is not None
- def test_mapped_dag(self, dag_maker):
+ @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"])
+ def test_mapped_dag(self, dag_id):
"""End-to-end test of a simple mapped dag"""
# Use SequentialExecutor for more predictable test behaviour
from airflow.executors.sequential_executor import SequentialExecutor
- self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py'))
- dag = self.dagbag.get_dag('test_mapped_classic')
+ self.dagbag.process_file(str(TEST_DAGS_FOLDER / f'{dag_id}.py'))
+ dag = self.dagbag.get_dag(dag_id)
# This needs a real executor to run, so that the `make_list` task can write out the TaskMap
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 707f587223f0a..845ffda016ac1 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -40,16 +40,17 @@
from airflow.executors.base_executor import BaseExecutor
from airflow.jobs.backfill_job import BackfillJob
from airflow.jobs.base_job import BaseJob
+from airflow.jobs.local_task_job import LocalTaskJob
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.models import DAG, DagBag, DagModel, Pool, TaskInstance
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import TaskInstanceKey
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
from airflow.operators.bash import BashOperator
from airflow.operators.dummy import DummyOperator
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils import timezone
-from airflow.utils.callback_requests import DagCallbackRequest
+from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -3480,6 +3481,114 @@ def test_timeout_triggers(self, dag_maker):
assert ti1.next_method == "__fail__"
assert ti2.state == State.DEFERRED
+ def test_find_zombies_nothing(self):
+ with create_session() as session:
+ self.scheduler_job = SchedulerJob()
+ self.scheduler_job.processor_agent = mock.MagicMock()
+
+ self.scheduler_job._find_zombies(session=session)
+
+ self.scheduler_job.processor_agent.send_callback_to_execute.assert_not_called()
+
+ def test_find_zombies(self):
+ dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
+ with create_session() as session:
+ session.query(LocalTaskJob).delete()
+ dag = dagbag.get_dag('example_branch_operator')
+ dag.sync_to_db()
+ task = dag.get_task(task_id='run_this_first')
+
+ dag_run = dag.create_dagrun(
+ state=DagRunState.RUNNING,
+ execution_date=DEFAULT_DATE,
+ run_type=DagRunType.SCHEDULED,
+ session=session,
+ )
+
+ ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING)
+ local_job = LocalTaskJob(ti)
+ local_job.state = State.SHUTDOWN
+
+ session.add(local_job)
+ session.flush()
+
+ ti.job_id = local_job.id
+ session.add(ti)
+ session.flush()
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+
+ self.scheduler_job._find_zombies(session=session)
+
+ self.scheduler_job.processor_agent.send_callback_to_execute.assert_called_once()
+ requests = self.scheduler_job.processor_agent.send_callback_to_execute.call_args[0]
+ assert 1 == len(requests)
+ assert requests[0].full_filepath == dag.fileloc
+ assert requests[0].msg == f"Detected {ti} as zombie"
+ assert requests[0].is_failure_callback is True
+ assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance)
+ assert ti.dag_id == requests[0].simple_task_instance.dag_id
+ assert ti.task_id == requests[0].simple_task_instance.task_id
+ assert ti.run_id == requests[0].simple_task_instance.run_id
+
+ session.query(TaskInstance).delete()
+ session.query(LocalTaskJob).delete()
+
+ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor(self):
+ """
+        Check that the same set of failure callbacks with zombies is passed to the dag
+ file processors until the next zombie detection logic is invoked.
+ """
+ with conf_vars({('core', 'load_examples'): 'False'}):
+ dagbag = DagBag(
+ dag_folder=os.path.join(settings.DAGS_FOLDER, "test_example_bash_operator.py"),
+ read_dags_from_db=False,
+ )
+ session = settings.Session()
+ session.query(LocalTaskJob).delete()
+ dag = dagbag.get_dag('test_example_bash_operator')
+ dag.sync_to_db()
+
+ dag_run = dag.create_dagrun(
+ state=DagRunState.RUNNING,
+ execution_date=DEFAULT_DATE,
+ run_type=DagRunType.SCHEDULED,
+ session=session,
+ )
+ task = dag.get_task(task_id='run_this_last')
+
+ ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING)
+ local_job = LocalTaskJob(ti)
+ local_job.state = State.SHUTDOWN
+ session.add(local_job)
+ session.flush()
+
+ # TODO: If there was an actual Relationship between TI and Job
+ # we wouldn't need this extra commit
+ session.add(ti)
+ ti.job_id = local_job.id
+ session.flush()
+
+ expected_failure_callback_requests = [
+ TaskCallbackRequest(
+ full_filepath=dag.fileloc,
+ simple_task_instance=SimpleTaskInstance(ti),
+ msg="Message",
+ )
+ ]
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+
+ self.scheduler_job._find_zombies(session=session)
+
+ self.scheduler_job.processor_agent.send_callback_to_execute.assert_called_once()
+ callback_requests = self.scheduler_job.processor_agent.send_callback_to_execute.call_args[0]
+ assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == {
+ result.simple_task_instance.key for result in callback_requests
+ }
+
@pytest.mark.xfail(reason="Work out where this goes")
def test_task_with_upstream_skip_process_task_instances():
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
index a23b50d42f893..dec0d378279a4 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
@@ -546,3 +546,20 @@ def test_execute_wildcard_with_replace_flag_false_with_destination_object(self,
mock.call(DESTINATION_BUCKET, prefix="foo/bar", delimiter=""),
]
mock_hook.return_value.list.assert_has_calls(mock_calls)
+
+ @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook')
+ def test_execute_source_object_required_flag_true(self, mock_hook):
+ mock_hook.return_value.exists.return_value = False
+ operator = GCSToGCSOperator(
+ task_id=TASK_ID,
+ source_bucket=TEST_BUCKET,
+ source_objects=SOURCE_OBJECTS_SINGLE_FILE,
+ destination_bucket=DESTINATION_BUCKET,
+ destination_object=DESTINATION_OBJECT_PREFIX,
+ source_object_required=True,
+ )
+
+ with pytest.raises(
+ AirflowException, match=f"{SOURCE_OBJECTS_SINGLE_FILE} does not exist in bucket {TEST_BUCKET}"
+ ):
+ operator.execute(None)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 447b1732a78b0..1e8d510fd7205 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1654,6 +1654,59 @@ def test_mapped_operator_xcomarg_serde():
assert xcom_arg.operator is serialized_dag.task_dict['op1']
+def test_mapped_decorator_serde():
+ from airflow.decorators import task
+ from airflow.models.xcom_arg import XComArg
+ from airflow.serialization.serialized_objects import _XComRef
+
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+ op1 = BaseOperator(task_id="op1")
+ xcomarg = XComArg(op1, "my_key")
+
+ @task(retry_delay=30)
+ def x(arg1, arg2, arg3, arg4):
+ print(arg1, arg2, arg3, arg4)
+
+ x.partial("foo", arg3=[1, 2, {"a": "b"}]).map({"a": 1, "b": 2}, arg4=xcomarg)
+
+ original = dag.get_task("x")
+
+ serialized = SerializedBaseOperator._serialize(original)
+ assert serialized == {
+ '_is_dummy': False,
+ '_is_mapped': True,
+ '_task_module': 'airflow.decorators.python',
+ '_task_type': '_PythonDecoratedOperator',
+ 'downstream_task_ids': [],
+ 'partial_kwargs': {
+ 'op_args': ["foo"],
+ 'op_kwargs': {'arg3': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+ 'retry_delay': 30,
+ },
+ 'mapped_kwargs': {
+ 'op_args': [{"__type": "dict", "__var": {'a': 1, 'b': 2}}],
+ 'op_kwargs': {'arg4': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'my_key'}}},
+ },
+ 'task_id': 'x',
+ 'template_ext': [],
+ 'template_fields': ['op_args', 'op_kwargs'],
+ }
+
+ deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ assert isinstance(deserialized, MappedOperator)
+ assert deserialized.deps is MappedOperator.DEFAULT_DEPS
+
+ assert deserialized.mapped_kwargs == {
+ "op_args": [{"a": 1, "b": 2}],
+ "op_kwargs": {"arg4": _XComRef("op1", "my_key")},
+ }
+ assert deserialized.partial_kwargs == {
+ "retry_delay": 30,
+ "op_args": ["foo"],
+ "op_kwargs": {"arg3": [1, 2, {"a": "b"}]},
+ }
+
+
def test_mapped_task_group_serde():
execution_date = datetime(2020, 1, 1)
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index c36f40887e2b3..72879fb8ee348 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -160,7 +160,7 @@ def test_build_airflow_url_with_query(self):
Test query generated with dag_id and params
"""
query = {"dag_id": "test_dag", "param": "key/to.encode"}
- expected_url = "/graph?dag_id=test_dag¶m=key%2Fto.encode"
+ expected_url = "/dags/test_dag/graph?param=key%2Fto.encode"
from airflow.www.app import cached_app
diff --git a/tests/utils/test_operator_resources.py b/tests/utils/test_operator_resources.py
new file mode 100644
index 0000000000000..fb15580b2173e
--- /dev/null
+++ b/tests/utils/test_operator_resources.py
@@ -0,0 +1,35 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow.utils.operator_resources import Resources
+
+
+class TestResources(unittest.TestCase):
+ def test_resource_eq(self):
+ r = Resources(cpus=0.1, ram=2048)
+ assert r not in [{}, [], None]
+ assert r == r
+
+ r2 = Resources(cpus=0.1, ram=2048)
+ assert r == r2
+ assert r2 == r
+
+ r3 = Resources(cpus=0.2, ram=2048)
+ assert r != r3
diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py
index 9b5383f2d654e..c211806bd1b04 100644
--- a/tests/www/views/test_views_decorators.py
+++ b/tests/www/views/test_views_decorators.py
@@ -115,7 +115,7 @@ def _check_last_log(session, dag_id, event, execution_date):
def test_action_logging_get(session, admin_client):
url = (
- f'graph?dag_id=example_bash_operator&'
+ f'dags/example_bash_operator/graph?'
f'execution_date={urllib.parse.quote_plus(str(EXAMPLE_DAG_DEFAULT_DATE))}'
)
resp = admin_client.get(url, follow_redirects=True)
@@ -131,6 +131,24 @@ def test_action_logging_get(session, admin_client):
)
+def test_action_logging_get_legacy_view(session, admin_client):
+ url = (
+ f'graph?dag_id=example_bash_operator&'
+ f'execution_date={urllib.parse.quote_plus(str(EXAMPLE_DAG_DEFAULT_DATE))}'
+ )
+ resp = admin_client.get(url, follow_redirects=True)
+ check_content_in_response('runme_1', resp)
+
+ # In mysql backend, this commit() is needed to write down the logs
+ session.commit()
+ _check_last_log(
+ session,
+ dag_id="example_bash_operator",
+ event="legacy_graph",
+ execution_date=EXAMPLE_DAG_DEFAULT_DATE,
+ )
+
+
def test_action_logging_post(session, admin_client):
form = dict(
task_id="runme_1",
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index d433542fd3264..bd0c7badd46f2 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -139,16 +139,26 @@ def client_ti_without_dag_edit(app):
pytest.param(
'dag_details?dag_id=example_bash_operator',
['DAG Details'],
- id="dag-details",
+ id="dag-details-url-param",
),
pytest.param(
'dag_details?dag_id=example_subdag_operator.section-1',
['DAG Details'],
+ id="dag-details-subdag-url-param",
+ ),
+ pytest.param(
+ 'dags/example_subdag_operator.section-1/details',
+ ['DAG Details'],
id="dag-details-subdag",
),
pytest.param(
'graph?dag_id=example_bash_operator',
['runme_1'],
+ id='graph-url-param',
+ ),
+ pytest.param(
+ 'dags/example_bash_operator/graph',
+ ['runme_1'],
id='graph',
),
pytest.param(
@@ -156,34 +166,69 @@ def client_ti_without_dag_edit(app):
['runme_1'],
id='tree',
),
+ pytest.param(
+ 'dags/example_bash_operator/grid',
+ ['runme_1'],
+ id='grid',
+ ),
pytest.param(
'tree?dag_id=example_subdag_operator.section-1',
['section-1-task-1'],
- id="tree-subdag",
+ id="tree-subdag-url-param",
+ ),
+ pytest.param(
+ 'dags/example_subdag_operator.section-1/grid',
+ ['section-1-task-1'],
+ id="grid-subdag",
),
pytest.param(
'duration?days=30&dag_id=example_bash_operator',
['example_bash_operator'],
+ id='duration-url-param',
+ ),
+ pytest.param(
+ 'dags/example_bash_operator/duration?days=30',
+ ['example_bash_operator'],
id='duration',
),
pytest.param(
'duration?days=30&dag_id=missing_dag',
['seems to be missing'],
+ id='duration-missing-url-param',
+ ),
+ pytest.param(
+ 'dags/missing_dag/duration?days=30',
+ ['seems to be missing'],
id='duration-missing',
),
pytest.param(
'tries?days=30&dag_id=example_bash_operator',
['example_bash_operator'],
+ id='tries-url-param',
+ ),
+ pytest.param(
+ 'dags/example_bash_operator/tries?days=30',
+ ['example_bash_operator'],
id='tries',
),
pytest.param(
'landing_times?days=30&dag_id=example_bash_operator',
['example_bash_operator'],
+ id='landing-times-url-param',
+ ),
+ pytest.param(
+ 'dags/example_bash_operator/landing-times?days=30',
+ ['example_bash_operator'],
id='landing-times',
),
pytest.param(
'gantt?dag_id=example_bash_operator',
['example_bash_operator'],
+ id="gantt-url-param",
+ ),
+ pytest.param(
+ 'dags/example_bash_operator/gantt',
+ ['example_bash_operator'],
id="gantt",
),
pytest.param(
@@ -196,21 +241,41 @@ def client_ti_without_dag_edit(app):
pytest.param(
"graph?dag_id=example_bash_operator",
["example_bash_operator"],
+ id="existing-dagbag-graph-url-param",
+ ),
+ pytest.param(
+ "dags/example_bash_operator/graph",
+ ["example_bash_operator"],
id="existing-dagbag-graph",
),
pytest.param(
"tree?dag_id=example_bash_operator",
["example_bash_operator"],
- id="existing-dagbag-tree",
+ id="existing-dagbag-tree-url-param",
+ ),
+ pytest.param(
+ "dags/example_bash_operator/grid",
+ ["example_bash_operator"],
+ id="existing-dagbag-grid",
),
pytest.param(
"calendar?dag_id=example_bash_operator",
["example_bash_operator"],
+ id="existing-dagbag-calendar-url-param",
+ ),
+ pytest.param(
+ "dags/example_bash_operator/calendar",
+ ["example_bash_operator"],
id="existing-dagbag-calendar",
),
pytest.param(
"dag_details?dag_id=example_bash_operator",
["example_bash_operator"],
+ id="existing-dagbag-dag-details-url-param",
+ ),
+ pytest.param(
+ "dags/example_bash_operator/details",
+ ["example_bash_operator"],
id="existing-dagbag-dag-details",
),
pytest.param(
@@ -274,7 +339,7 @@ def test_tree_trigger_origin_tree_view(app, admin_client):
url = 'tree?dag_id=test_tree_view'
resp = admin_client.get(url, follow_redirects=True)
- params = {'dag_id': 'test_tree_view', 'origin': '/tree?dag_id=test_tree_view'}
+ params = {'dag_id': 'test_tree_view', 'origin': '/dags/test_tree_view/grid'}
href = f"/trigger?{html.escape(urllib.parse.urlencode(params))}"
check_content_in_response(href, resp)
@@ -288,9 +353,9 @@ def test_graph_trigger_origin_graph_view(app, admin_client):
state=State.RUNNING,
)
- url = 'graph?dag_id=test_tree_view'
+ url = '/dags/test_tree_view/graph'
resp = admin_client.get(url, follow_redirects=True)
- params = {'dag_id': 'test_tree_view', 'origin': '/graph?dag_id=test_tree_view'}
+ params = {'dag_id': 'test_tree_view', 'origin': '/dags/test_tree_view/graph'}
href = f"/trigger?{html.escape(urllib.parse.urlencode(params))}"
check_content_in_response(href, resp)
@@ -304,9 +369,9 @@ def test_dag_details_trigger_origin_dag_details_view(app, admin_client):
state=State.RUNNING,
)
- url = 'dag_details?dag_id=test_graph_view'
+ url = '/dags/test_graph_view/details'
resp = admin_client.get(url, follow_redirects=True)
- params = {'dag_id': 'test_graph_view', 'origin': '/dag_details?dag_id=test_graph_view'}
+ params = {'dag_id': 'test_graph_view', 'origin': '/dags/test_graph_view/details'}
href = f"/trigger?{html.escape(urllib.parse.urlencode(params))}"
check_content_in_response(href, resp)
@@ -348,7 +413,7 @@ def test_code_from_db(admin_client):
dag = DagBag(include_examples=True).get_dag("example_bash_operator")
DagCode(dag.fileloc, DagCode._get_code_from_file(dag.fileloc)).sync_to_db()
url = 'code?dag_id=example_bash_operator'
- resp = admin_client.get(url)
+ resp = admin_client.get(url, follow_redirects=True)
check_content_not_in_response('Failed to load DAG file Code', resp)
check_content_in_response('example_bash_operator', resp)
@@ -358,7 +423,7 @@ def test_code_from_db_all_example_dags(admin_client):
for dag in dagbag.dags.values():
DagCode(dag.fileloc, DagCode._get_code_from_file(dag.fileloc)).sync_to_db()
url = 'code?dag_id=example_bash_operator'
- resp = admin_client.get(url)
+ resp = admin_client.get(url, follow_redirects=True)
check_content_not_in_response('Failed to load DAG file Code', resp)
check_content_in_response('example_bash_operator', resp)