From 5a4b7a38eb7838049e4b7e9ee71cc792b3871eae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com> Date: Wed, 30 Aug 2023 21:07:26 +0200 Subject: [PATCH] Refactor unneeded 'continue' jumps around the repo --- airflow/decorators/base.py | 14 ++++++------ airflow/executors/debug_executor.py | 9 +++----- airflow/plugins_manager.py | 15 ++++++------- airflow/providers_manager.py | 25 ++++++++++----------- airflow/serialization/serialized_objects.py | 11 ++++----- airflow/template/templater.py | 4 +--- 6 files changed, 34 insertions(+), 44 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index d601483a3a4ac..31a354e02018d 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -112,10 +112,11 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> kwargs_left = kwargs.copy() for arg_name in self._mappable_function_argument_names: value = kwargs_left.pop(arg_name, NOTSET) - if func != "expand" or value is NOTSET or is_mappable(value): - continue - tname = type(value).__name__ - raise ValueError(f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}") + if func == "expand" and value is not NOTSET and not is_mappable(value): + tname = type(value).__name__ + raise ValueError( + f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}" + ) if len(kwargs_left) == 1: raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}") elif kwargs_left: @@ -157,9 +158,8 @@ def _find_id_suffixes(dag: DAG) -> Iterator[int]: prefix = re2.split(r"__\d+$", tg_task_id)[0] for task_id in dag.task_ids: match = re2.match(rf"^{prefix}__(\d+)$", task_id) - if match is None: - continue - yield int(match.group(1)) + if match: + yield int(match.group(1)) yield 0 # Default if there's no matching task ID. core = re2.split(r"__\d+$", task_id)[0] diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index 4ecebdff8b4f8..be2b657b7556e 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -71,15 +71,12 @@ def sync(self) -> None: self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED) ti.set_state(TaskInstanceState.UPSTREAM_FAILED) self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED) - continue - - if self._terminated.is_set(): + elif self._terminated.is_set(): self.log.info("Executor is terminated! Stopping %s to %s", ti.key, TaskInstanceState.FAILED) ti.set_state(TaskInstanceState.FAILED) self.change_state(ti.key, TaskInstanceState.FAILED) - continue - - task_succeeded = self._run_task(ti) + else: + task_succeeded = self._run_task(ti) def _run_task(self, ti: TaskInstance) -> bool: self.log.debug("Executing task: %s", ti) diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 0970e2ca3c10d..9e724aa9b51c8 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -26,6 +26,7 @@ import os import sys import types +from pathlib import Path from typing import TYPE_CHECKING, Any, Iterable from airflow import settings @@ -251,11 +252,10 @@ def load_plugins_from_plugin_directory(): log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER) for file_path in find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore"): - if not os.path.isfile(file_path): - continue - mod_name, file_ext = os.path.splitext(os.path.split(file_path)[-1]) - if file_ext != ".py": + path = Path(file_path) + if not path.is_file() or path.suffix != ".py": continue + mod_name = path.stem try: loader = importlib.machinery.SourceFileLoader(mod_name, file_path) @@ -285,13 +285,12 @@ def load_providers_plugins(): try: plugin_instance = import_string(plugin.plugin_class) - if not is_valid_plugin(plugin_instance): + if is_valid_plugin(plugin_instance): + register_plugin(plugin_instance) + else: log.warning("Plugin %s is not a valid plugin", plugin.name) - continue - register_plugin(plugin_instance) except ImportError: log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class) - continue def make_module(name: str, objects: list[Any]): diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 2b62f796ce64e..a5502bce416ee 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -634,10 +634,9 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None: # The same path can appear in the __path__ twice, under non-normalized paths (ie. # /path/to/repo/airflow/providers and /path/to/repo/./airflow/providers) path = os.path.realpath(path) - if path in seen: - continue - seen.add(path) - self._add_provider_info_from_local_source_files_on_path(path) + if path not in seen: + seen.add(path) + self._add_provider_info_from_local_source_files_on_path(path) except Exception as e: log.warning(f"Error when loading 'provider.yaml' files from {path} airflow sources: {e}") @@ -1004,15 +1003,15 @@ def _add_widgets(self, package_name: str, hook_class: type, widgets: dict[str, A hook_class.__name__, ) # In case of inherited hooks this might be happening several times - continue - self._connection_form_widgets[prefixed_field_name] = ConnectionFormWidgetInfo( - hook_class.__name__, - package_name, - field, - field_identifier, - hasattr(field.field_class.widget, "input_type") - and field.field_class.widget.input_type == "password", - ) + else: + self._connection_form_widgets[prefixed_field_name] = ConnectionFormWidgetInfo( + hook_class.__name__, + package_name, + field, + field_identifier, + hasattr(field.field_class.widget, "input_type") + and field.field_class.widget.input_type == "password", + ) def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: dict): try: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 7433e706e9fbc..a7a712cf11ee0 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -660,13 +660,10 @@ def is_serialized(val): return False for attr in attrs: - if attr not in param_dict: - continue - val = param_dict[attr] - if is_serialized(val): - deserialized_val = cls.deserialize(param_dict[attr]) - kwargs[attr] = deserialized_val - else: + if attr in param_dict: + val = param_dict[attr] + if is_serialized(val): + val = cls.deserialize(val) kwargs[attr] = val return class_(**kwargs) diff --git a/airflow/template/templater.py b/airflow/template/templater.py index 07aead85800d0..9cb6a312ad3ac 100644 --- a/airflow/template/templater.py +++ b/airflow/template/templater.py @@ -68,9 +68,7 @@ def resolve_template_files(self) -> None: if self.template_ext: for field in self.template_fields: content = getattr(self, field, None) - if content is None: - continue - elif isinstance(content, str) and content.endswith(tuple(self.template_ext)): + if isinstance(content, str) and content.endswith(tuple(self.template_ext)): env = self.get_template_env() try: setattr(self, field, env.loader.get_source(env, content)[0]) # type: ignore