Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 164 additions & 122 deletions airflow/jobs/scheduler_job.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ def resolve_template_files(self) -> None:
self.prepare_template()

@property
def upstream_list(self) -> List[str]:
def upstream_list(self) -> List["BaseOperator"]:
"""@property: list of tasks directly upstream"""
return [self.dag.get_task(tid) for tid in self._upstream_task_ids]

Expand All @@ -989,7 +989,7 @@ def upstream_task_ids(self) -> Set[str]:
return self._upstream_task_ids

@property
def downstream_list(self) -> List[str]:
def downstream_list(self) -> List["BaseOperator"]:
"""@property: list of tasks directly downstream"""
return [self.dag.get_task(tid) for tid in self._downstream_task_ids]

Expand Down Expand Up @@ -1123,7 +1123,7 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
else:
return self._downstream_task_ids

def get_direct_relatives(self, upstream: bool = False) -> List[str]:
def get_direct_relatives(self, upstream: bool = False) -> List["BaseOperator"]:
"""
Get list of the direct relatives to the current task, upstream or
downstream.
Expand Down
34 changes: 17 additions & 17 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(
self._dag_id = dag_id
self._full_filepath = full_filepath if full_filepath else ''
self._concurrency = concurrency
self._pickle_id = None
self._pickle_id: Optional[int] = None

self._description = description
# set file location to caller source path
Expand Down Expand Up @@ -304,7 +304,7 @@ def __init__(
self.dagrun_timeout = dagrun_timeout
self.sla_miss_callback = sla_miss_callback
if default_view in DEFAULT_VIEW_PRESETS:
self._default_view = default_view
self._default_view: str = default_view
else:
raise AirflowException(f'Invalid values of dag.default_view: only support '
f'{DEFAULT_VIEW_PRESETS}, but get {default_view}')
Expand Down Expand Up @@ -507,27 +507,27 @@ def get_last_dagrun(self, session=None, include_externally_triggered=False):
include_externally_triggered=include_externally_triggered)

@property
def dag_id(self):
def dag_id(self) -> str:
return self._dag_id

@dag_id.setter
def dag_id(self, value):
def dag_id(self, value: str) -> None:
self._dag_id = value

@property
def full_filepath(self):
def full_filepath(self) -> str:
return self._full_filepath

@full_filepath.setter
def full_filepath(self, value):
def full_filepath(self, value) -> None:
self._full_filepath = value

@property
def concurrency(self):
def concurrency(self) -> int:
return self._concurrency

@concurrency.setter
def concurrency(self, value):
def concurrency(self, value: int):
self._concurrency = value

@property
Expand All @@ -539,23 +539,23 @@ def access_control(self, value):
self._access_control = value

@property
def description(self):
def description(self) -> Optional[str]:
return self._description

@property
def default_view(self):
def default_view(self) -> str:
return self._default_view

@property
def pickle_id(self):
def pickle_id(self) -> Optional[int]:
return self._pickle_id

@pickle_id.setter
def pickle_id(self, value):
def pickle_id(self, value: int) -> None:
self._pickle_id = value

@property
def tasks(self):
def tasks(self) -> List[BaseOperator]:
return list(self.task_dict.values())

@tasks.setter
Expand All @@ -564,7 +564,7 @@ def tasks(self, val):
'DAG.tasks can not be modified. Use dag.add_task() instead.')

@property
def task_ids(self):
def task_ids(self) -> List[str]:
return list(self.task_dict.keys())

@property
Expand Down Expand Up @@ -1264,10 +1264,10 @@ def sub_dag(self, task_regex, include_downstream=False,

return dag

def has_task(self, task_id):
def has_task(self, task_id: str):
return task_id in (t.task_id for t in self.tasks)

def get_task(self, task_id, include_subdags=False):
def get_task(self, task_id: str, include_subdags: bool = False) -> BaseOperator:
if task_id in self.task_dict:
return self.task_dict[task_id]
if include_subdags:
Expand All @@ -1291,7 +1291,7 @@ def pickle_info(self):
return d

@provide_session
def pickle(self, session=None):
def pickle(self, session=None) -> DagPickle:
dag = session.query(
DagModel).filter(DagModel.dag_id == self.dag_id).first()
dp = None
Expand Down
20 changes: 11 additions & 9 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import textwrap
import zipfile
from datetime import datetime, timedelta
from typing import List, NamedTuple
from typing import Dict, List, NamedTuple, Optional

from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter
from tabulate import tabulate
Expand Down Expand Up @@ -79,18 +79,20 @@ class DagBag(BaseDagBag, LoggingMixin):

def __init__(
self,
dag_folder=None,
include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
store_serialized_dags=False,
dag_folder: Optional[str] = None,
include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
store_serialized_dags: bool = False,
):
# Avoid circular import
from airflow.models.dag import DAG
super().__init__()
dag_folder = dag_folder or settings.DAGS_FOLDER
self.dag_folder = dag_folder
self.dags = {}
self.dags: Dict[str, DAG] = {}
# the file's last modified timestamp when we last read it
self.file_last_changed = {}
self.import_errors = {}
self.file_last_changed: Dict[str, datetime] = {}
self.import_errors: Dict[str, str] = {}
self.has_logged = False
self.store_serialized_dags = store_serialized_dags

Expand All @@ -99,7 +101,7 @@ def __init__(
include_examples=include_examples,
safe_mode=safe_mode)

def size(self):
def size(self) -> int:
"""
:return: the amount of dags contained in this dagbag
"""
Expand Down
24 changes: 17 additions & 7 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import List, Optional, Tuple, Union, cast
from typing import Any, List, Optional, Tuple, Union, cast

from sqlalchemy import (
Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_,
Expand Down Expand Up @@ -66,8 +66,17 @@ class DagRun(Base, LoggingMixin):
UniqueConstraint('dag_id', 'run_id'),
)

def __init__(self, dag_id=None, run_id=None, execution_date=None, start_date=None, external_trigger=None,
conf=None, state=None, run_type=None):
def __init__(
self,
dag_id: Optional[str] = None,
run_id: Optional[str] = None,
execution_date: Optional[datetime] = None,
start_date: Optional[datetime] = None,
external_trigger: Optional[bool] = None,
conf: Optional[Any] = None,
state: Optional[str] = None,
run_type: Optional[str] = None
):
self.dag_id = dag_id
self.run_id = run_id
self.execution_date = execution_date
Expand Down Expand Up @@ -131,8 +140,9 @@ def find(
no_backfills: Optional[bool] = False,
run_type: Optional[DagRunType] = None,
session: Session = None,
execution_start_date=None, execution_end_date=None
):
execution_start_date: Optional[datetime] = None,
execution_end_date: Optional[datetime] = None
) -> List["DagRun"]:
"""
Returns a set of dag runs for the given search criteria.

Expand Down Expand Up @@ -281,7 +291,7 @@ def get_previous_scheduled_dagrun(self, session=None):
).first()

@provide_session
def update_state(self, session=None):
def update_state(self, session=None) -> List[TI]:
"""
Determines the overall state of the DagRun based on the state
of its TaskInstances.
Expand All @@ -291,7 +301,7 @@ def update_state(self, session=None):
"""

dag = self.get_dag()
ready_tis = []
ready_tis: List[TI] = []
tis = [ti for ti in self.get_task_instances(session=session,
state=State.task_states + (State.SHUTDOWN,))]
self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def queued_slots(self, session: Session):
) or 0

@provide_session
def open_slots(self, session: Session):
def open_slots(self, session: Session) -> float:
"""
Get the number of slots open at the moment.

Expand Down
14 changes: 6 additions & 8 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def generate_command(dag_id: str, # pylint: disable=too-many-arguments
ignore_task_deps: Optional[bool] = False,
ignore_ti_state: Optional[bool] = False,
local: Optional[bool] = False,
pickle_id: Optional[str] = None,
pickle_id: Optional[int] = None,
file_path: Optional[str] = None,
raw: Optional[bool] = False,
job_id: Optional[str] = None,
Expand Down Expand Up @@ -372,7 +372,7 @@ def generate_command(dag_id: str, # pylint: disable=too-many-arguments
:type local: Optional[bool]
:param pickle_id: If the DAG was serialized to the DB, the ID
associated with the pickled DAG
:type pickle_id: Optional[str]
:type pickle_id: Optional[int]
:param file_path: path to the file containing the DAG definition
:type file_path: Optional[str]
:param raw: raw mode (needs more details)
Expand All @@ -391,7 +391,7 @@ def generate_command(dag_id: str, # pylint: disable=too-many-arguments
if mark_success:
cmd.extend(["--mark-success"])
if pickle_id:
cmd.extend(["--pickle", pickle_id])
cmd.extend(["--pickle", str(pickle_id)])
if job_id:
cmd.extend(["--job-id", str(job_id)])
if ignore_all_deps:
Expand Down Expand Up @@ -573,7 +573,7 @@ def key(self) -> TaskInstanceKey:
return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, self.try_number)

@provide_session
def set_state(self, state, session=None, commit=True):
def set_state(self, state: str, session=None, commit: bool = True):
"""
Set TaskInstance state

Expand Down Expand Up @@ -1779,9 +1779,7 @@ def __init__(self, ti: TaskInstance):
self._run_as_user: Optional[str] = None
if hasattr(ti, 'run_as_user'):
self._run_as_user = ti.run_as_user
self._pool: Optional[str] = None
if hasattr(ti, 'pool'):
self._pool = ti.pool
self._pool: str = ti.pool
self._priority_weight: Optional[int] = None
if hasattr(ti, 'priority_weight'):
self._priority_weight = ti.priority_weight
Expand Down Expand Up @@ -1818,7 +1816,7 @@ def state(self) -> str:
return self._state

@property
def pool(self) -> Any:
def pool(self) -> str:
return self._pool

@property
Expand Down
9 changes: 6 additions & 3 deletions airflow/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import string
import textwrap
from functools import wraps
from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, InvalidStatsNameException
Expand Down Expand Up @@ -255,5 +255,8 @@ def get_constant_tags(self):
return tags


class Stats(metaclass=_Stats): # noqa: D101
pass
if TYPE_CHECKING:
Stats: StatsLogger
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we have metaclass magic, which mypy does not handle well.

else:
class Stats(metaclass=_Stats): # noqa: D101
pass
14 changes: 7 additions & 7 deletions airflow/ti_deps/dep_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ class DepContext:
def __init__(
self,
deps=None,
flag_upstream_failed=False,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_in_retry_period=False,
ignore_in_reschedule_period=False,
ignore_task_deps=False,
ignore_ti_state=False,
flag_upstream_failed: bool = False,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
ignore_in_retry_period: bool = False,
ignore_in_reschedule_period: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
finished_tasks=None):
self.deps = deps or set()
self.flag_upstream_failed = flag_upstream_failed
Expand Down
Loading