diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 259303c3073e8..a73abec8487f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -232,7 +232,7 @@ repos: name: Run pydocstyle args: - --convention=pep257 - - --add-ignore=D100,D102,D103,D104,D105,D107,D205,D400,D401 + - --add-ignore=D100,D102,D103,D104,D105,D107,D202,D205,D400,D401 exclude: | (?x) ^tests/.*\.py$| diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 8ef738f4f4ad6..83ea6b94ccff5 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1850,6 +1850,13 @@ type: string example: ~ default: "300" + - name: zombie_detection_interval + description: | + How often (in seconds) should the scheduler check for zombie tasks. + version_added: 2.3.0 + type: float + example: ~ + default: "10.0" - name: catchup_by_default description: | Turn off scheduler catchup by setting this to ``False``. diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 520ab4442850d..55161a55d5e71 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -923,6 +923,9 @@ child_process_log_directory = {AIRFLOW_HOME}/logs/scheduler # associated task instance as failed and will re-schedule the task. scheduler_zombie_task_threshold = 300 +# How often (in seconds) should the scheduler check for zombie tasks. +zombie_detection_interval = 10.0 + # Turn off scheduler catchup by setting this to ``False``. # Default behavior is unchanged and # Command Line Backfills still work, but the scheduler diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 3b8a998551aca..33b219ccd58ea 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -34,7 +34,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast from setproctitle import setproctitle -from sqlalchemy import or_ from tabulate import tabulate import airflow.models @@ -42,17 +41,15 @@ from airflow.dag_processing.processor import DagFileProcessorProcess from airflow.models import DagModel, errors from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import SimpleTaskInstance from airflow.stats import Stats from airflow.utils import timezone -from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest, TaskCallbackRequest +from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest from airflow.utils.file import list_py_file_paths, might_contain_dag from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.mixins import MultiprocessingStartMethodMixin from airflow.utils.net import get_hostname from airflow.utils.process_utils import kill_child_processes_by_pids, reap_process_group from airflow.utils.session import provide_session -from airflow.utils.state import State if TYPE_CHECKING: import pathlib @@ -434,8 +431,6 @@ def __init__( # How often to print out DAG file processing stats to the log. Default to # 30 seconds. self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval') - # How many seconds do we wait for tasks to heartbeat before mark them as zombies. 
- self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold') # Map from file path to the processor self._processors: Dict[str, DagFileProcessorProcess] = {} @@ -445,13 +440,10 @@ def __init__( # Map from file path to stats about the file self._file_stats: Dict[str, DagFileStat] = {} - self._last_zombie_query_time = None # Last time that the DAG dir was traversed to look for files self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0)) # Last time stats were printed self.last_stat_print_time = 0 - # TODO: Remove magic number - self._zombie_query_interval = 10 # How long to wait before timing out a process to parse a DAG file self._processor_timeout = processor_timeout @@ -566,7 +558,6 @@ def _run_parsing_loop(self): self._processors.pop(processor.file_path) self._refresh_dag_dir() - self._find_zombies() self._kill_timed_out_processors() @@ -1023,53 +1014,6 @@ def prepare_file_path_queue(self): self._file_path_queue.extend(files_paths_to_queue) - @provide_session - def _find_zombies(self, session): - """ - Find zombie task instances, which are tasks haven't heartbeated for too long - and update the current zombie list. - """ - now = timezone.utcnow() - if ( - not self._last_zombie_query_time - or (now - self._last_zombie_query_time).total_seconds() > self._zombie_query_interval - ): - # to avoid circular imports - from airflow.jobs.local_task_job import LocalTaskJob as LJ - - self.log.info("Finding 'running' jobs without a recent heartbeat") - TI = airflow.models.TaskInstance - DM = airflow.models.DagModel - limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs) - - zombies = ( - session.query(TI, DM.fileloc) - .join(LJ, TI.job_id == LJ.id) - .join(DM, TI.dag_id == DM.dag_id) - .filter(TI.state == State.RUNNING) - .filter( - or_( - LJ.state != State.RUNNING, - LJ.latest_heartbeat < limit_dttm, - ) - ) - .all() - ) - - if zombies: - self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm) - - self._last_zombie_query_time = timezone.utcnow() - for ti, file_loc in zombies: - request = TaskCallbackRequest( - full_filepath=file_loc, - simple_task_instance=SimpleTaskInstance(ti), - msg=f"Detected {ti} as zombie", - ) - self.log.error("Detected zombie job: %s", request) - self._add_callback_to_queue(request) - Stats.incr('zombies_killed') - def _kill_timed_out_processors(self): """Kill any file processors that timeout to defend against process hangs.""" now = timezone.utcnow() diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9cf423fb69e11..53a12c62d2318 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -280,30 +280,88 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names names = ", ".join(repr(n) for n in unknown_args) raise TypeError(f'{funcname} got unexpected keyword arguments {names}') - def map( - self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs - ) -> XComArg: + def map(self, *args, **kwargs) -> XComArg: self._validate_arg_names("map", kwargs) - dag = dag or DagContext.get_current_dag() - task_group = task_group or TaskGroupContext.get_current_task_group(dag) - task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group) - operator = MappedOperator.from_decorator( - decorator=self, + partial_kwargs = self.kwargs.copy() + dag = partial_kwargs.pop("dag", DagContext.get_current_dag()) + task_group = partial_kwargs.pop("task_group", 
TaskGroupContext.get_current_task_group(dag)) + task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) + + # Unfortunately attrs's type hinting support does not work well with + # subclassing; it complains that arguments forwarded to the superclass + # are "unexpected" (they are fine at runtime). + operator = cast(Any, DecoratedMappedOperator)( + operator_class=self.operator_class, + partial_kwargs=partial_kwargs, + mapped_kwargs={}, + task_id=task_id, dag=dag, task_group=task_group, - task_id=task_id, - mapped_kwargs=kwargs, + deps=MappedOperator._deps(self.operator_class.deps), + multiple_outputs=self.multiple_outputs, + python_callable=self.function, ) + + operator.mapped_kwargs["op_args"] = list(args) + operator.mapped_kwargs["op_kwargs"] = kwargs + + for arg in itertools.chain(args, kwargs.values()): + XComArg.apply_upstream_relationship(operator, arg) return XComArg(operator=operator) - def partial( - self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs - ) -> "_TaskDecorator[Function, OperatorSubclass]": - self._validate_arg_names("partial", kwargs, {'task_id'}) - partial_kwargs = self.kwargs.copy() - partial_kwargs.update(kwargs) - return attr.evolve(self, kwargs=partial_kwargs) + def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]": + self._validate_arg_names("partial", kwargs) + + op_args = self.kwargs.get("op_args", []) + op_args.extend(args) + + op_kwargs = self.kwargs.get("op_kwargs", {}) + op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial") + + return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs}) + + +def _merge_kwargs( + kwargs1: Dict[str, XComArg], + kwargs2: Dict[str, XComArg], + *, + fail_reason: str, +) -> Dict[str, XComArg]: + duplicated_keys = set(kwargs1).intersection(kwargs2) + if len(duplicated_keys) == 1: + raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}") + elif duplicated_keys: + duplicated_keys_display = ", ".join(sorted(duplicated_keys)) + raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") + return {**kwargs1, **kwargs2} + + +@attr.define(kw_only=True) +class DecoratedMappedOperator(MappedOperator): + """MappedOperator implementation for @task-decorated task function.""" + + multiple_outputs: bool + python_callable: Callable + + def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: + assert not isinstance(self.operator_class, str) + op_args = self.partial_kwargs.pop("op_args", []) + self.mapped_kwargs.pop("op_args", []) + op_kwargs = _merge_kwargs( + self.partial_kwargs.pop("op_kwargs", {}), + self.mapped_kwargs.pop("op_kwargs", {}), + fail_reason="mapping already partial", + ) + return self.operator_class( + dag=dag, + task_id=self.task_id, + op_args=op_args, + op_kwargs=op_kwargs, + multiple_outputs=self.multiple_outputs, + python_callable=self.python_callable, + **self.partial_kwargs, + **self.mapped_kwargs, + ) class Task(Generic[Function]): diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 7a6e3efd74ba9..62116553c675e 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -38,6 +38,7 @@ from airflow.dag_processing.manager import DagFileProcessorAgent from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS from airflow.jobs.base_job import BaseJob +from airflow.jobs.local_task_job import LocalTaskJob from airflow.models import DAG from airflow.models.dag import DagModel from 
airflow.models.dagbag import DagBag @@ -123,6 +124,8 @@ def __init__( ) scheduler_idle_sleep_time = processor_poll_interval self._scheduler_idle_sleep_time = scheduler_idle_sleep_time + # How many seconds do we wait for tasks to heartbeat before marking them as zombies. + self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold') self.do_pickle = do_pickle super().__init__(*args, **kwargs) @@ -739,6 +742,11 @@ def _run_scheduler_loop(self) -> None: self._emit_pool_metrics, ) + timers.call_regular_interval( + conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0), + self._find_zombies, + ) + for loop_count in itertools.count(start=1): with Stats.timer() as timer: @@ -1259,3 +1267,39 @@ def check_trigger_timeouts(self, session: Session = None): ) if num_timed_out_tasks: self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks) + + @provide_session + def _find_zombies(self, session): + """ + Find zombie task instances, which are tasks that haven't heartbeated for too long, + and fail them by sending a task callback to the DAG processor. + """ + self.log.debug("Finding 'running' jobs without a recent heartbeat") + limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs) + + zombies = ( + session.query(TaskInstance, DagModel.fileloc) + .join(LocalTaskJob, TaskInstance.job_id == LocalTaskJob.id) + .join(DagModel, TaskInstance.dag_id == DagModel.dag_id) + .filter(TaskInstance.state == State.RUNNING) + .filter( + or_( + LocalTaskJob.state != State.RUNNING, + LocalTaskJob.latest_heartbeat < limit_dttm, + ) + ) + .all() + ) + + if zombies: + self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm) + + for ti, file_loc in zombies: + request = TaskCallbackRequest( + full_filepath=file_loc, + simple_task_instance=SimpleTaskInstance(ti), + msg=f"Detected {ti} as zombie", + ) + self.log.error("Detected zombie job: %s", request) + self.processor_agent.send_callback_to_execute(request) + Stats.incr('zombies_killed') diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 34c84128391dc..35a0fbb53fe8a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -82,7 +82,6 @@ from airflow.utils.weight_rule import WeightRule if TYPE_CHECKING: - from airflow.decorators.base import _TaskDecorator from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup @@ -243,7 +242,7 @@ def __new__(cls, name, bases, namespace, **kwargs): return new_cls # The class level partial function. This is what handles the actual mapping - def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs): + def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs) -> "MappedOperator": operator_class = cast("Type[BaseOperator]", cls) # Validate that the args we passed are known -- at call/DAG parse time, not run time!
_validate_kwarg_names_for_mapping(operator_class, "partial", kwargs) @@ -671,7 +670,7 @@ def __init__( ) self.trigger_rule = trigger_rule - self.depends_on_past = depends_on_past + self.depends_on_past: bool = depends_on_past self.wait_for_downstream = wait_for_downstream if wait_for_downstream: self.depends_on_past = True @@ -714,7 +713,7 @@ def __init__( stacklevel=2, ) max_active_tis_per_dag = task_concurrency - self.max_active_tis_per_dag = max_active_tis_per_dag + self.max_active_tis_per_dag: Optional[int] = max_active_tis_per_dag self.do_xcom_push = do_xcom_push self.doc_md = doc_md @@ -1632,7 +1631,7 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> dag._remove_task(operator.task_id) operator_init_kwargs: dict = operator._BaseOperator__init_kwargs # type: ignore - return MappedOperator( + return cls( operator_class=type(operator), task_id=operator.task_id, task_group=task_group, @@ -1648,37 +1647,6 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> deps=cls._deps(operator.deps), ) - @classmethod - def from_decorator( - cls, - *, - decorator: "_TaskDecorator", - dag: Optional["DAG"], - task_group: Optional["TaskGroup"], - task_id: str, - mapped_kwargs: Dict[str, Any], - ) -> "MappedOperator": - """Create a mapped operator from a task decorator. - - Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``. - The task decorator calling this should be responsible for validation. - """ - from airflow.models.xcom_arg import XComArg - - operator = MappedOperator( - operator_class=decorator.operator_class, - partial_kwargs=decorator.kwargs, - mapped_kwargs={}, - task_id=task_id, - dag=dag, - task_group=task_group, - deps=cls._deps(decorator.operator_class.deps), - ) - operator.mapped_kwargs.update(mapped_kwargs) - for arg in mapped_kwargs.values(): - XComArg.apply_upstream_relationship(operator, arg) - return operator - @classmethod def _deps(cls, deps: Iterable[BaseTIDep]): if deps is BaseOperator.deps: @@ -1749,7 +1717,7 @@ def inherits_from_dummy_operator(self): @classmethod def get_serialized_fields(cls): if cls.__serialized_fields is None: - fields_dict = attr.fields_dict(cls) + fields_dict = attr.fields_dict(MappedOperator) cls.__serialized_fields = frozenset( fields_dict.keys() - { @@ -1902,22 +1870,17 @@ def expand_mapped_task( return ret - def unmap(self) -> BaseOperator: - """Get the "normal" Operator after applying the current mapping""" + def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: assert not isinstance(self.operator_class, str) + return self.operator_class(dag=dag, task_id=self.task_id, **self.partial_kwargs, **self.mapped_kwargs) + def unmap(self) -> BaseOperator: + """Get the "normal" Operator after applying the current mapping""" dag = self.get_dag() if not dag: - raise RuntimeError("Cannot unmapp a task unless it has a dag") - - args = { - **self.partial_kwargs, - **self.mapped_kwargs, - } + raise RuntimeError("Cannot unmap a task unless it has a DAG") dag._remove_task(self.task_id) - task = self.operator_class(task_id=self.task_id, dag=self.dag, **args) - - return task + return self.create_unmapped_operator(dag) # TODO: Deprecate for Airflow 3.0 diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index eac62dcd39506..0136d7fc13fd9 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -89,7 +89,6 @@ class DagBag(LoggingMixin): """ DAGBAG_IMPORT_TIMEOUT = conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT') - SCHEDULER_ZOMBIE_TASK_THRESHOLD = 
conf.getint('scheduler', 'scheduler_zombie_task_threshold') def __init__( self, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4996b9a7db073..f10032dfc2964 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1713,7 +1713,7 @@ def handle_failure( test_mode: Optional[bool] = None, force_fail: bool = False, error_file: Optional[str] = None, - session=NEW_SESSION, + session: Session = NEW_SESSION, ) -> None: """Handle Failure for the TaskInstance""" if test_mode is None: diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py index afd19912aa1a9..26e10d7e5f4ce 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py @@ -14,11 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -import os +import json from datetime import datetime from airflow.models.dag import DAG @@ -37,7 +33,6 @@ { "cluster_name": "templated-cluster", "cluster_role_arn": "arn:aws:iam::123456789012:role/role_name", - "nodegroup_subnets": ["subnet-12345ab", "subnet-67890cd"], "resources_vpc_config": { "subnetIds": ["subnet-12345ab", "subnet-67890cd"], "endpointPublicAccess": true, @@ -49,25 +44,24 @@ """ with DAG( - dag_id='to-publish-manuals-templated', - default_args={'cluster_name': "{{ dag_run.conf['cluster_name'] }}"}, + dag_id='example_eks_templated', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example', 'templated'], + catchup=False, # render_template_as_native_obj=True is what converts the Jinja to Python objects, instead of a string. render_template_as_native_obj=True, ) as dag: - SUBNETS = os.environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ') - VPC_CONFIG = { - 'subnetIds': SUBNETS, - 'endpointPublicAccess': True, - 'endpointPrivateAccess': False, - } + + CLUSTER_NAME = "{{ dag_run.conf['cluster_name'] }}" + NODEGROUP_NAME = "{{ dag_run.conf['nodegroup_name'] }}" + VPC_CONFIG = json.loads("{{ dag_run.conf['resources_vpc_config'] }}") + SUBNETS = VPC_CONFIG['subnetIds'] + # Create an Amazon EKS Cluster control plane without attaching a compute service. 
create_cluster = EksCreateClusterOperator( task_id='create_eks_cluster', + cluster_name=CLUSTER_NAME, compute=None, cluster_role_arn="{{ dag_run.conf['cluster_role_arn'] }}", resources_vpc_config=VPC_CONFIG, @@ -75,24 +69,28 @@ await_create_cluster = EksClusterStateSensor( task_id='wait_for_create_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.ACTIVE, ) create_nodegroup = EksCreateNodegroupOperator( task_id='create_eks_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", - nodegroup_subnets="{{ dag_run.conf['nodegroup_subnets'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, + nodegroup_subnets=SUBNETS, nodegroup_role_arn="{{ dag_run.conf['nodegroup_role_arn'] }}", ) await_create_nodegroup = EksNodegroupStateSensor( task_id='wait_for_create_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.ACTIVE, ) start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "ls"], @@ -104,21 +102,25 @@ delete_nodegroup = EksDeleteNodegroupOperator( task_id='delete_eks_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, ) await_delete_nodegroup = EksNodegroupStateSensor( task_id='wait_for_delete_nodegroup', - nodegroup_name="{{ dag_run.conf['nodegroup_name'] }}", + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.NONEXISTENT, ) delete_cluster = EksDeleteClusterOperator( task_id='delete_eks_cluster', + cluster_name=CLUSTER_NAME, ) await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py index 4107058b5ac20..e08e6525e6fc8 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py @@ -43,19 +43,18 @@ with DAG( - dag_id='example-create-cluster-and-fargate-all-in-one', - default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_fargate_in_one_step', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # [START howto_operator_eks_create_cluster_with_fargate_profile] # Create an Amazon EKS cluster control plane and an AWS Fargate compute platform in one step. create_cluster_and_fargate_profile = EksCreateClusterOperator( task_id='create_eks_cluster_and_fargate_profile', + cluster_name=CLUSTER_NAME, cluster_role_arn=ROLE_ARN, resources_vpc_config=VPC_CONFIG, compute='fargate', @@ -68,6 +67,7 @@ await_create_fargate_profile = EksFargateProfileStateSensor( task_id='wait_for_create_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, target_state=FargateProfileStates.ACTIVE, ) @@ -75,6 +75,7 @@ start_pod = EksPodOperator( task_id="run_pod", pod_name="run_pod", + cluster_name=CLUSTER_NAME, image="amazon/aws-cli:latest", cmds=["sh", "-c", "echo Test Airflow; date"], labels={"demo": "hello_world"}, @@ -86,11 +87,14 @@ # An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles. 
# Setting the `force` to `True` will delete any attached resources before deleting the cluster. delete_all = EksDeleteClusterOperator( - task_id='delete_fargate_profile_and_cluster', force_delete_compute=True + task_id='delete_fargate_profile_and_cluster', + cluster_name=CLUSTER_NAME, + force_delete_compute=True, ) await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py index e58e2de729b62..3ca3b2eb87728 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py @@ -46,18 +46,17 @@ with DAG( - dag_id='example_eks_with_fargate_profile_dag', - default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_fargate_profile', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # Create an Amazon EKS Cluster control plane without attaching a compute service. create_cluster = EksCreateClusterOperator( task_id='create_eks_cluster', + cluster_name=CLUSTER_NAME, cluster_role_arn=ROLE_ARN, resources_vpc_config=VPC_CONFIG, compute=None, @@ -65,26 +64,32 @@ await_create_cluster = EksClusterStateSensor( task_id='wait_for_create_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.ACTIVE, ) # [START howto_operator_eks_create_fargate_profile] create_fargate_profile = EksCreateFargateProfileOperator( task_id='create_eks_fargate_profile', + cluster_name=CLUSTER_NAME, pod_execution_role_arn=ROLE_ARN, fargate_profile_name=FARGATE_PROFILE_NAME, selectors=SELECTORS, ) # [END howto_operator_eks_create_fargate_profile] + # [START howto_sensor_eks_fargate] await_create_fargate_profile = EksFargateProfileStateSensor( task_id='wait_for_create_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, target_state=FargateProfileStates.ACTIVE, ) + # [END howto_sensor_eks_fargate] start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "echo Test Airflow; date"], @@ -97,20 +102,26 @@ # [START howto_operator_eks_delete_fargate_profile] delete_fargate_profile = EksDeleteFargateProfileOperator( task_id='delete_eks_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, ) # [END howto_operator_eks_delete_fargate_profile] await_delete_fargate_profile = EksFargateProfileStateSensor( task_id='wait_for_delete_fargate_profile', + cluster_name=CLUSTER_NAME, fargate_profile_name=FARGATE_PROFILE_NAME, target_state=FargateProfileStates.NONEXISTENT, ) - delete_cluster = EksDeleteClusterOperator(task_id='delete_eks_cluster') + delete_cluster = EksDeleteClusterOperator( + task_id='delete_eks_cluster', + cluster_name=CLUSTER_NAME, + ) await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py index f19eec622f295..38d1bd1ad4c2f 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py +++ 
b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py @@ -42,19 +42,18 @@ with DAG( - dag_id='example_eks_using_defaults_dag', - default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_nodegroup_in_one_step', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # [START howto_operator_eks_create_cluster_with_nodegroup] # Create an Amazon EKS cluster control plane and an EKS nodegroup compute platform in one step. create_cluster_and_nodegroup = EksCreateClusterOperator( task_id='create_eks_cluster_and_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, cluster_role_arn=ROLE_ARN, nodegroup_role_arn=ROLE_ARN, @@ -68,12 +67,14 @@ await_create_nodegroup = EksNodegroupStateSensor( task_id='wait_for_create_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.ACTIVE, ) start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "echo Test Airflow; date"], @@ -86,11 +87,16 @@ # [START howto_operator_eks_force_delete_cluster] # An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles. # Setting the `force` to `True` will delete any attached resources before deleting the cluster. - delete_all = EksDeleteClusterOperator(task_id='delete_nodegroup_and_cluster', force_delete_compute=True) + delete_all = EksDeleteClusterOperator( + task_id='delete_nodegroup_and_cluster', + cluster_name=CLUSTER_NAME, + force_delete_compute=True, + ) # [END howto_operator_eks_force_delete_cluster] await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py index 3ec6a3ac459a0..efeeb14e04bfb 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py @@ -45,48 +45,55 @@ with DAG( - dag_id='example_eks_with_nodegroups_dag', - default_args={'cluster_name': CLUSTER_NAME}, + dag_id='example_eks_with_nodegroups', schedule_interval=None, start_date=datetime(2021, 1, 1), - catchup=False, - max_active_runs=1, tags=['example'], + catchup=False, ) as dag: # [START howto_operator_eks_create_cluster] # Create an Amazon EKS Cluster control plane without attaching compute service. 
create_cluster = EksCreateClusterOperator( task_id='create_eks_cluster', + cluster_name=CLUSTER_NAME, cluster_role_arn=ROLE_ARN, resources_vpc_config=VPC_CONFIG, compute=None, ) # [END howto_operator_eks_create_cluster] + # [START howto_sensor_eks_cluster] await_create_cluster = EksClusterStateSensor( task_id='wait_for_create_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.ACTIVE, ) + # [END howto_sensor_eks_cluster] # [START howto_operator_eks_create_nodegroup] create_nodegroup = EksCreateNodegroupOperator( task_id='create_eks_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, nodegroup_subnets=SUBNETS, nodegroup_role_arn=ROLE_ARN, ) # [END howto_operator_eks_create_nodegroup] + # [START howto_sensor_eks_nodegroup] await_create_nodegroup = EksNodegroupStateSensor( task_id='wait_for_create_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.ACTIVE, ) + # [END howto_sensor_eks_nodegroup] # [START howto_operator_eks_pod_operator] start_pod = EksPodOperator( task_id="run_pod", + cluster_name=CLUSTER_NAME, pod_name="run_pod", image="amazon/aws-cli:latest", cmds=["sh", "-c", "ls"], @@ -99,22 +106,29 @@ # [START howto_operator_eks_delete_nodegroup] delete_nodegroup = EksDeleteNodegroupOperator( - task_id='delete_eks_nodegroup', nodegroup_name=NODEGROUP_NAME + task_id='delete_eks_nodegroup', + cluster_name=CLUSTER_NAME, + nodegroup_name=NODEGROUP_NAME, ) # [END howto_operator_eks_delete_nodegroup] await_delete_nodegroup = EksNodegroupStateSensor( task_id='wait_for_delete_nodegroup', + cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.NONEXISTENT, ) # [START howto_operator_eks_delete_cluster] - delete_cluster = EksDeleteClusterOperator(task_id='delete_eks_cluster') + delete_cluster = EksDeleteClusterOperator( + task_id='delete_eks_cluster', + cluster_name=CLUSTER_NAME, + ) # [END howto_operator_eks_delete_cluster] await_delete_cluster = EksClusterStateSensor( task_id='wait_for_delete_cluster', + cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_sns.py b/airflow/providers/amazon/aws/example_dags/example_sns.py new file mode 100644 index 0000000000000..782156b14c3d3 --- /dev/null +++ b/airflow/providers/amazon/aws/example_dags/example_sns.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from datetime import datetime +from os import environ + +from airflow import DAG +from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator + +SNS_TOPIC_ARN = environ.get('SNS_TOPIC_ARN', 'arn:aws:sns:us-west-2:123456789012:dummy-topic-name') + +with DAG( + dag_id='example_sns', + schedule_interval=None, + start_date=datetime(2021, 1, 1), + tags=['example'], + catchup=False, +) as dag: + + # [START howto_operator_sns_publish_operator] + publish = SnsPublishOperator( + task_id='publish_message', + target_arn=SNS_TOPIC_ARN, + message='This is a sample message sent to SNS via an Apache Airflow DAG task.', + ) + # [END howto_operator_sns_publish_operator] diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py index 48a436b020745..e916798d03386 100644 --- a/airflow/providers/amazon/aws/operators/sns.py +++ b/airflow/providers/amazon/aws/operators/sns.py @@ -30,6 +30,10 @@ class SnsPublishOperator(BaseOperator): """ Publish a message to Amazon SNS. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SnsPublishOperator` + :param aws_conn_id: aws connection to use :param target_arn: either a TopicArn or an EndpointArn :param message: the default message you want to send (templated) diff --git a/airflow/providers/amazon/aws/sensors/eks.py b/airflow/providers/amazon/aws/sensors/eks.py index 7f639b684103f..92ed55da4d31e 100644 --- a/airflow/providers/amazon/aws/sensors/eks.py +++ b/airflow/providers/amazon/aws/sensors/eks.py @@ -60,6 +60,10 @@ class EksClusterStateSensor(BaseSensorOperator): """ Check the state of an Amazon EKS Cluster until it reaches the target state or another terminal state. + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EksClusterStateSensor` + :param cluster_name: The name of the Cluster to watch. (templated) :param target_state: Target state of the Cluster. (templated) :param region: Which AWS region the connection should use. (templated) @@ -116,6 +120,10 @@ class EksFargateProfileStateSensor(BaseSensorOperator): """ Check the state of an AWS Fargate profile until it reaches the target state or another terminal state. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/sensor:EksFargateProfileStateSensor` + :param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. (templated) :param fargate_profile_name: The name of the Fargate profile to watch. (templated) :param target_state: Target state of the Fargate profile. (templated) @@ -183,6 +191,10 @@ class EksNodegroupStateSensor(BaseSensorOperator): """ Check the state of an EKS managed node group until it reaches the target state or another terminal state. + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EksNodegroupStateSensor` + :param cluster_name: The name of the Cluster which the Nodegroup is attached to. (templated) :param nodegroup_name: The name of the Nodegroup to watch. (templated) :param target_state: Target state of the Nodegroup. 
(templated) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index f6887e0477e6c..962e48b10541c 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -127,6 +127,8 @@ integrations: - integration-name: Amazon Simple Notification Service (SNS) external-doc-url: https://aws.amazon.com/sns/ logo: /integration-logos/aws/Amazon-Simple-Notification-Service-SNS_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/sns.rst tags: [aws] - integration-name: Amazon Simple Queue Service (SQS) external-doc-url: https://aws.amazon.com/sqs/ diff --git a/airflow/providers/apache/pig/example_dags/example_pig.py b/airflow/providers/apache/pig/example_dags/example_pig.py index e3ab899e2f5be..ed1b34ab0c8a4 100644 --- a/airflow/providers/apache/pig/example_dags/example_pig.py +++ b/airflow/providers/apache/pig/example_dags/example_pig.py @@ -30,9 +30,11 @@ tags=['example'], ) +# [START create_pig] run_this = PigOperator( task_id="run_example_pig_script", pig="ls /;", pig_opts="-x local", dag=dag, ) +# [END create_pig] diff --git a/airflow/providers/apache/pig/provider.yaml b/airflow/providers/apache/pig/provider.yaml index ae1cf0870dc5d..434e307b96ff1 100644 --- a/airflow/providers/apache/pig/provider.yaml +++ b/airflow/providers/apache/pig/provider.yaml @@ -19,7 +19,7 @@ package-name: apache-airflow-providers-apache-pig name: Apache Pig description: | - `Apache Pig `__ + `Apache Pig `__ versions: - 2.0.1 @@ -33,6 +33,8 @@ additional-dependencies: integrations: - integration-name: Apache Pig external-doc-url: https://pig.apache.org/ + how-to-guide: + - /docs/apache-airflow-providers-apache-pig/operators.rst logo: /integration-logos/apache/pig.png tags: [apache] @@ -46,7 +48,8 @@ hooks: python-modules: - airflow.providers.apache.pig.hooks.pig -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ +hook-class-names: + # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - airflow.providers.apache.pig.hooks.pig.PigCliHook connection-types: diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index 7e6498cefd5ca..499b331da19c5 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -90,6 +90,8 @@ class GCSToGCSOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param source_object_required: Whether you want to raise an exception when the source object + doesn't exist. It doesn't have any effect when the source objects are folders or patterns. 
:Example: @@ -190,6 +192,7 @@ def __init__( maximum_modified_time=None, is_older_than=None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + source_object_required=False, **kwargs, ): super().__init__(**kwargs) @@ -216,6 +219,7 @@ def __init__( self.maximum_modified_time = maximum_modified_time self.is_older_than = is_older_than self.impersonation_chain = impersonation_chain + self.source_object_required = source_object_required def execute(self, context: 'Context'): @@ -313,6 +317,11 @@ def _copy_source_without_wildcard(self, hook, prefix): self._copy_single_object( hook=hook, source_object=prefix, destination_object=self.destination_object ) + elif self.source_object_required: + msg = f"{prefix} does not exist in bucket {self.source_bucket}" + self.log.warning(msg) + raise AirflowException(msg) + for source_obj in objects: if self.destination_object is None: destination_object = source_obj diff --git a/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py b/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py index 8332ba86d226b..084ebf4f2860d 100644 --- a/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py +++ b/airflow/providers/zendesk/example_dags/example_zendesk_custom_get.py @@ -19,11 +19,12 @@ from typing import Dict, List from airflow import DAG -from airflow.operators.python import PythonOperator +from airflow.decorators import task from airflow.providers.zendesk.hooks.zendesk import ZendeskHook -def zendesk_custom_get_request() -> List[Dict]: +@task +def fetch_organizations() -> List[Dict]: hook = ZendeskHook() response = hook.get( url="https://yourdomain.zendesk.com/api/v2/organizations.json", @@ -37,7 +38,4 @@ def zendesk_custom_get_request() -> List[Dict]: start_date=datetime(2021, 1, 1), catchup=False, ) as dag: - fetch_organizations = PythonOperator( - task_id="trigger_zendesk_hook", - python_callable=zendesk_custom_get_request, - ) + fetch_organizations() diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d6abda7c74899..017f2276964ca 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -16,6 +16,7 @@ # under the License. """Serialized DAG and BaseOperator""" +import contextlib import datetime import enum import logging @@ -168,7 +169,7 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable: return timetable_class.deserialize(var[Encoding.VAR]) -class _XcomRef(NamedTuple): +class _XComRef(NamedTuple): """ Used to store info needed to create XComArg when deserializing MappedOperator. @@ -497,8 +498,8 @@ def _serialize_xcomarg(cls, arg: XComArg) -> dict: return {"key": arg.key, "task_id": arg.operator.task_id} @classmethod - def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef: - return _XcomRef(key=encoded['key'], task_id=encoded['task_id']) + def _deserialize_xcomref(cls, encoded: dict) -> _XComRef: + return _XComRef(key=encoded['key'], task_id=encoded['task_id']) class DependencyDetector: @@ -566,9 +567,19 @@ def task_type(self, task_type: str): @classmethod def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: - stock_deps = op.deps is MappedOperator.DEFAULT_DEPS serialize_op = cls._serialize_node(op, include_deps=not stock_deps) + + # Simplify op_kwargs format. It must be a dict, so we flatten it. 
+ with contextlib.suppress(KeyError): + op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"] + assert op_kwargs[Encoding.TYPE] == DAT.DICT + serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] + with contextlib.suppress(KeyError): + op_kwargs = serialize_op["partial_kwargs"]["op_kwargs"] + assert op_kwargs[Encoding.TYPE] == DAT.DICT + serialize_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] + # It must be a class at this point for it to work, not a string assert isinstance(op.operator_class, type) serialize_op['_task_type'] = op.operator_class.__name__ @@ -715,7 +726,13 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, elif k == "params": v = cls._deserialize_params_dict(v) elif k in ("mapped_kwargs", "partial_kwargs"): + if "op_kwargs" not in v: + op_kwargs: Optional[dict] = None + else: + op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()} v = {arg: cls._deserialize(value) for arg, value in v.items()} + if op_kwargs is not None: + v["op_kwargs"] = op_kwargs elif k in cls._decorated_fields or k not in op.get_serialized_fields(): v = cls._deserialize(v) # else use v as it is @@ -1002,7 +1019,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': if isinstance(task, MappedOperator): for d in (task.mapped_kwargs, task.partial_kwargs): for k, v in d.items(): - if not isinstance(v, _XcomRef): + if not isinstance(v, _XComRef): continue d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key) diff --git a/airflow/task/task_runner/cgroup_task_runner.py b/airflow/task/task_runner/cgroup_task_runner.py index 15075de64c567..d6c6e53abf935 100644 --- a/airflow/task/task_runner/cgroup_task_runner.py +++ b/airflow/task/task_runner/cgroup_task_runner.py @@ -146,11 +146,11 @@ def start(self): self._mem_mb_limit = resources.ram.qty # Create the memory cgroup - mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name) + self.mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name) self._created_mem_cgroup = True if self._mem_mb_limit > 0: self.log.debug("Setting %s with %s MB of memory", self.mem_cgroup_name, self._mem_mb_limit) - mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024 + self.mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024 # Create the CPU cgroup cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name) @@ -185,11 +185,32 @@ def terminate(self): if self.process and psutil.pid_exists(self.process.pid): reap_process_group(self.process.pid, self.log) + def _log_memory_usage(self, mem_cgroup_node): + def byte_to_gb(num_bytes, precision=2): + return round(num_bytes / (1024 * 1024 * 1024), precision) + + with open(mem_cgroup_node.full_path + '/memory.max_usage_in_bytes') as f: + max_usage_in_bytes = int(f.read().strip()) + + used_gb = byte_to_gb(max_usage_in_bytes) + limit_gb = byte_to_gb(mem_cgroup_node.controller.limit_in_bytes) + + self.log.info( + "Memory max usage of the task is %s GB, while the memory limit is %s GB", used_gb, limit_gb + ) + + if max_usage_in_bytes >= mem_cgroup_node.controller.limit_in_bytes: + self.log.info( + "This task has reached the memory limit allocated by Airflow worker. " + "If it failed, try to optimize the task or reserve more memory." 
+ ) + def on_finish(self): # Let the OOM watcher thread know we're done to avoid false OOM alarms self._finished_running = True # Clean up the cgroups if self._created_mem_cgroup: + self._log_memory_usage(self.mem_cgroup_node) self._delete_cgroup(self.mem_cgroup_name) if self._created_cpu_cgroup: self._delete_cgroup(self.cpu_cgroup_name) diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 5bdbf0c14f8dd..a85ab8971cd1c 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -35,7 +35,6 @@ Tuple, TypeVar, ) -from urllib import parse from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -258,8 +257,7 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str: import flask view = conf.get('webserver', 'dag_default_view').lower() - url = flask.url_for(f"Airflow.{view}") - return f"{url}?{parse.urlencode(query)}" + return flask.url_for(f"Airflow.{view}", **query) # The 'template' argument is typed as Any because the jinja2.Template is too diff --git a/airflow/utils/operator_resources.py b/airflow/utils/operator_resources.py index 8c3263247faa9..28710cddd42cc 100644 --- a/airflow/utils/operator_resources.py +++ b/airflow/utils/operator_resources.py @@ -50,6 +50,8 @@ def __init__(self, name, units_str, qty): self._qty = qty def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented return self.__dict__ == other.__dict__ def __repr__(self): @@ -126,6 +128,8 @@ def __init__( self.gpus = GpuResource(gpus) def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented return self.__dict__ == other.__dict__ def __repr__(self): diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index 080fe682991c2..47a9847c2eecf 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -20,6 +20,7 @@ import gzip import logging from io import BytesIO as IO +from itertools import chain from typing import Callable, TypeVar, cast import pendulum @@ -48,13 +49,19 @@ def wrapper(*args, **kwargs): user = g.user.username fields_skip_logging = {'csrf_token', '_csrf_token'} + log_fields = { + k: v + for k, v in chain(request.values.items(), request.view_args.items()) + if k not in fields_skip_logging + } + log = Log( event=f.__name__, task_instance=None, owner=user, - extra=str([(k, v) for k, v in request.values.items() if k not in fields_skip_logging]), - task_id=request.values.get('task_id'), - dag_id=request.values.get('dag_id'), + extra=str([(k, log_fields[k]) for k in log_fields]), + task_id=log_fields.get('task_id'), + dag_id=log_fields.get('dag_id'), ) if 'execution_date' in request.values: diff --git a/airflow/www/fab_security/sqla/models.py b/airflow/www/fab_security/sqla/models.py index 69853722d59b7..93a95a45c5e21 100644 --- a/airflow/www/fab_security/sqla/models.py +++ b/airflow/www/fab_security/sqla/models.py @@ -37,7 +37,6 @@ ) from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import backref, relationship -from sqlalchemy.orm.relationships import foreign """ Compatibility note: The models in this file are duplicated from Flask AppBuilder. 
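The `__eq__` guards added to `airflow/utils/operator_resources.py` above follow the usual Python convention of returning `NotImplemented` when the other operand is a different class, so comparison falls back to Python's default handling instead of failing on `other.__dict__`. A minimal, self-contained sketch of the pattern (the `Resource` class below is illustrative only, not the actual Airflow class):

# Illustrative sketch of the NotImplemented guard; this toy Resource class is
# hypothetical and only mirrors the shape of the change above.
class Resource:
    def __init__(self, name, qty):
        self.name = name
        self.qty = qty

    def __eq__(self, other):
        # Without this guard, comparing against None or an int would read
        # other.__dict__ and raise AttributeError instead of returning False.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__

print(Resource("cpus", 1) == Resource("cpus", 1))  # True
print(Resource("cpus", 1) == 0)                    # False: Python falls back to identity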
@@ -140,17 +139,13 @@ class Permission(Model): action_id = Column("permission_id", Integer, ForeignKey("ab_permission.id")) action = relationship( "Action", - primaryjoin=action_id == foreign(Action.id), uselist=False, - backref="permission", lazy="joined", ) resource_id = Column("view_menu_id", Integer, ForeignKey("ab_view_menu.id")) resource = relationship( "Resource", - primaryjoin=resource_id == foreign(Resource.id), uselist=False, - backref="permission", lazy="joined", ) diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index e57a1f1fc8b3b..0f4b967373da5 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -110,9 +110,9 @@
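Taken together, the scheduler changes above move zombie detection out of the DAG file processor and into `SchedulerJob._run_scheduler_loop`, driven by the new `[scheduler] zombie_detection_interval` option. A rough sketch of how the two settings are read, assuming only the config keys shown in this diff; the scheduler loop and callback handling are simplified away:

from airflow.configuration import conf

# Tasks whose LocalTaskJob has stopped heartbeating for longer than this are treated as zombies.
zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')  # default 300

# New in this change: how often the scheduler-side check runs. The fallback keeps
# the old hard-coded 10-second interval for configs written before the option existed.
zombie_detection_interval = conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0)

# In the scheduler loop this becomes roughly:
#   timers.call_regular_interval(zombie_detection_interval, self._find_zombies)
# and _find_zombies sends a TaskCallbackRequest to the DAG processor for each zombie found.

The default matches the previous behavior (the old `_zombie_query_interval = 10` magic number in the manager), but the interval is now tunable per deployment.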