diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index dfcef7318313a..f80b7513d8ee8 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -68,7 +68,7 @@ def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None): from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey from jwt.algorithms import OKPAlgorithm, RSAAlgorithm - if isinstance(key, RSAPrivateKey | Ed25519PrivateKey): + if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)): key = key.public_key() if isinstance(key, RSAPublicKey): diff --git a/airflow-core/src/airflow/api_fastapi/common/parameters.py b/airflow-core/src/airflow/api_fastapi/common/parameters.py index c6e3bc945b020..14ab5efa8ef4d 100644 --- a/airflow-core/src/airflow/api_fastapi/common/parameters.py +++ b/airflow-core/src/airflow/api_fastapi/common/parameters.py @@ -264,7 +264,7 @@ def __init__( self.filter_option: FilterOptionEnum = filter_option def to_orm(self, select: Select) -> Select: - if isinstance(self.value, list | str) and not self.value and self.skip_none: + if isinstance(self.value, (list, str)) and not self.value and self.skip_none: return select if self.value is None and self.skip_none: return select diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py index 00bbf543288fe..0208ea1a0a5f7 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py @@ -64,7 +64,7 @@ def handle_bulk_create(self, action: BulkCreateAction, results: BulkActionRespon for variable in action.entities: if variable.key in create_keys: - should_serialize_json = isinstance(variable.value, dict | list) + should_serialize_json = isinstance(variable.value, (dict, list)) Variable.set( key=variable.key, value=variable.value, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py index 9c356e9c3ad84..a69cafb7bbbda 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py @@ -133,7 +133,7 @@ def get_child_task_map(parent_task_id: str, task_node_map: dict[str, dict[str, A def _count_tis(node: int | MappedTaskGroup | MappedOperator, run_id: str, session: SessionDep) -> int: - if not isinstance(node, MappedTaskGroup | MappedOperator): + if not isinstance(node, (MappedTaskGroup, MappedOperator)): return node with contextlib.suppress(NotFullyPopulated, NotMapped): return DBBaseOperator.get_mapped_ti_count(node, run_id=run_id, session=session) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 0a57e604a6eb6..e22d0a5f34d0d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -434,7 +434,7 @@ def _create_ti_state_update_query_and_update_state( dag_bag: DagBagDep, dag_id: str, ) -> tuple[Update, TaskInstanceState]: - if isinstance(ti_patch_payload, TITerminalStatePayload | TIRetryStatePayload | TISuccessStatePayload): + if isinstance(ti_patch_payload, (TITerminalStatePayload, 
TIRetryStatePayload, TISuccessStatePayload)): ti = session.get(TI, ti_id_str) updated_state = ti_patch_payload.state query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) diff --git a/airflow-core/src/airflow/cli/commands/plugins_command.py b/airflow-core/src/airflow/cli/commands/plugins_command.py index f26546439e940..29dd75674afe0 100644 --- a/airflow-core/src/airflow/cli/commands/plugins_command.py +++ b/airflow-core/src/airflow/cli/commands/plugins_command.py @@ -27,7 +27,7 @@ def _get_name(class_like_object) -> str: - if isinstance(class_like_object, str | PluginsDirectorySource): + if isinstance(class_like_object, (str, PluginsDirectorySource)): return str(class_like_object) if inspect.isclass(class_like_object): return class_like_object.__name__ diff --git a/airflow-core/src/airflow/cli/simple_table.py b/airflow-core/src/airflow/cli/simple_table.py index 01025f3d64762..53125e3f74830 100644 --- a/airflow-core/src/airflow/cli/simple_table.py +++ b/airflow-core/src/airflow/cli/simple_table.py @@ -84,7 +84,7 @@ def print_as_plain_table(self, data: list[dict]): print(output) def _normalize_data(self, value: Any, output: str) -> list | str | dict | None: - if isinstance(value, tuple | list): + if isinstance(value, (tuple, list)): if output == "table": return ",".join(str(self._normalize_data(x, output)) for x in value) return [self._normalize_data(x, output) for x in value] diff --git a/airflow-core/src/airflow/exceptions.py b/airflow-core/src/airflow/exceptions.py index 2d12b0baabf32..045f9647ade76 100644 --- a/airflow-core/src/airflow/exceptions.py +++ b/airflow-core/src/airflow/exceptions.py @@ -476,7 +476,7 @@ def __init__( self.kwargs = kwargs self.timeout: timedelta | None # Check timeout type at runtime - if isinstance(timeout, int | float): + if isinstance(timeout, (int, float)): self.timeout = timedelta(seconds=timeout) else: self.timeout = timeout diff --git a/airflow-core/src/airflow/models/asset.py b/airflow-core/src/airflow/models/asset.py index 79cf8ef28a622..a345ebe7995d3 100644 --- a/airflow-core/src/airflow/models/asset.py +++ b/airflow-core/src/airflow/models/asset.py @@ -209,7 +209,7 @@ def __hash__(self): def __eq__(self, other: object) -> bool: from airflow.sdk.definitions.asset import AssetAlias - if isinstance(other, self.__class__ | AssetAlias): + if isinstance(other, (self.__class__, AssetAlias)): return self.name == other.name return NotImplemented @@ -306,7 +306,7 @@ def __init__(self, name: str = "", uri: str = "", **kwargs): def __eq__(self, other: object) -> bool: from airflow.sdk.definitions.asset import Asset - if isinstance(other, self.__class__ | Asset): + if isinstance(other, (self.__class__, Asset)): return self.name == other.name and self.uri == other.uri return NotImplemented diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index a74a68ac2c493..2fe689e6168d5 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -463,7 +463,7 @@ def _upgrade_outdated_dag_access_control(access_control=None): for role, perms in access_control.items(): if packaging_version.parse(FAB_VERSION) >= packaging_version.parse("1.3.0"): updated_access_control[role] = updated_access_control.get(role, {}) - if isinstance(perms, set | list): + if isinstance(perms, (set, list)): # Support for old-style access_control where only the actions are specified updated_access_control[role][permissions.RESOURCE_DAG] = set(perms) else: @@ -541,7 +541,7 @@ def 
infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: :meta private: """ timetable_type = type(self.timetable) - if issubclass(timetable_type, NullTimetable | OnceTimetable | AssetTriggeredTimetable): + if issubclass(timetable_type, (NullTimetable, OnceTimetable, AssetTriggeredTimetable)): return DataInterval.exact(timezone.coerce_datetime(logical_date)) start = timezone.coerce_datetime(logical_date) if issubclass(timetable_type, CronDataIntervalTimetable): @@ -959,7 +959,7 @@ def _get_task_instances( tis = tis.where(DagRun.logical_date <= end_date) if state: - if isinstance(state, str | TaskInstanceState): + if isinstance(state, (str, TaskInstanceState)): tis = tis.where(TaskInstance.state == state) elif len(state) == 1: tis = tis.where(TaskInstance.state == state[0]) diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index df13c38581d76..83fc47444c179 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -448,7 +448,7 @@ def parse(mod_name, filepath): dagbag_import_timeout = settings.get_dagbag_import_timeout(filepath) - if not isinstance(dagbag_import_timeout, int | float): + if not isinstance(dagbag_import_timeout, (int, float)): raise TypeError( f"Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float" ) @@ -520,7 +520,7 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk): from airflow.sdk import DAG as SDKDAG from airflow.sdk.definitions._internal.contextmanager import DagContext - top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG | SDKDAG)} + top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, (DAG, SDKDAG))} top_level_dags.update(DagContext.autoregistered_dags) diff --git a/airflow-core/src/airflow/models/expandinput.py b/airflow-core/src/airflow/models/expandinput.py index 46b70be79a954..803a8fe294c6d 100644 --- a/airflow-core/src/airflow/models/expandinput.py +++ b/airflow-core/src/airflow/models/expandinput.py @@ -54,7 +54,7 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | SchedulerXComArg]: from airflow.models.xcom_arg import SchedulerXComArg - return isinstance(v, MappedArgument | SchedulerXComArg) + return isinstance(v, (MappedArgument, SchedulerXComArg)) @attrs.define diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py index ff3b94a84dacf..f0fd4c0231b70 100644 --- a/airflow-core/src/airflow/models/taskmap.py +++ b/airflow-core/src/airflow/models/taskmap.py @@ -137,7 +137,7 @@ def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Seq from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.settings import task_instance_mutation_hook - if not isinstance(task, BaseOperator | MappedOperator): + if not isinstance(task, (BaseOperator, MappedOperator)): raise RuntimeError( f"cannot expand unrecognized operator type {type(task).__module__}.{type(task).__name__}" ) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 33676d58bec06..1db0a91236199 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -744,7 +744,7 @@ def serialize( return cls._encode(var.timestamp(), type_=DAT.DATETIME) elif isinstance(var, datetime.timedelta): return 
cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA) - elif isinstance(var, Timezone | FixedTimezone): + elif isinstance(var, (Timezone, FixedTimezone)): return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE) elif isinstance(var, relativedelta.relativedelta): return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA) @@ -753,7 +753,7 @@ def serialize( var._asdict(), type_=DAT.TASK_INSTANCE_KEY, ) - elif isinstance(var, AirflowException | TaskDeferred) and hasattr(var, "serialize"): + elif isinstance(var, (AirflowException, TaskDeferred)) and hasattr(var, "serialize"): exc_cls_name, args, kwargs = var.serialize() return cls._encode( cls.serialize( @@ -762,7 +762,7 @@ def serialize( ), type_=DAT.AIRFLOW_EXC_SER, ) - elif isinstance(var, KeyError | AttributeError): + elif isinstance(var, (KeyError, AttributeError)): return cls._encode( cls.serialize( { diff --git a/airflow-core/src/airflow/serialization/serializers/datetime.py b/airflow-core/src/airflow/serialization/serializers/datetime.py index a99c690f3cb72..69058b8c02a8b 100644 --- a/airflow-core/src/airflow/serialization/serializers/datetime.py +++ b/airflow-core/src/airflow/serialization/serializers/datetime.py @@ -92,7 +92,7 @@ def deserialize(classname: str, version: int, data: dict | str) -> datetime.date if classname == qualname(DateTime) and isinstance(data, dict): return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=tz) - if classname == qualname(datetime.timedelta) and isinstance(data, str | float): + if classname == qualname(datetime.timedelta) and isinstance(data, (str, float)): return datetime.timedelta(seconds=float(data)) if classname == qualname(datetime.date) and isinstance(data, str): diff --git a/airflow-core/src/airflow/serialization/serializers/kubernetes.py b/airflow-core/src/airflow/serialization/serializers/kubernetes.py index 908bd0dc29e17..faa2312ac7a81 100644 --- a/airflow-core/src/airflow/serialization/serializers/kubernetes.py +++ b/airflow-core/src/airflow/serialization/serializers/kubernetes.py @@ -43,7 +43,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]: if not k8s: return "", "", 0, False - if isinstance(o, k8s.V1Pod | k8s.V1ResourceRequirements): + if isinstance(o, (k8s.V1Pod, k8s.V1ResourceRequirements)): from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # We're running this in an except block, so we don't want it to fail diff --git a/airflow-core/src/airflow/serialization/serializers/numpy.py b/airflow-core/src/airflow/serialization/serializers/numpy.py index 6e317b47d3885..1e32f90a081aa 100644 --- a/airflow-core/src/airflow/serialization/serializers/numpy.py +++ b/airflow-core/src/airflow/serialization/serializers/numpy.py @@ -72,7 +72,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]: if isinstance(o, np.bool_): return bool(o), name, __version__, True - if isinstance(o, np.float16 | np.float32 | np.float64 | np.complex64 | np.complex128): + if isinstance(o, (np.float16, np.float32, np.float64, np.complex64, np.complex128)): return float(o), name, __version__, True return "", "", 0, False diff --git a/airflow-core/src/airflow/serialization/serializers/timezone.py b/airflow-core/src/airflow/serialization/serializers/timezone.py index ca2f25ed59010..9f2ef7cef65ac 100644 --- a/airflow-core/src/airflow/serialization/serializers/timezone.py +++ b/airflow-core/src/airflow/serialization/serializers/timezone.py @@ -70,7 +70,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]: def deserialize(classname: str, version: int, data: object) -> Any: 
from airflow.utils.timezone import parse_timezone - if not isinstance(data, str | int): + if not isinstance(data, (str, int)): raise TypeError(f"{data} is not of type int or str but of {type(data)}") if version > __version__: diff --git a/airflow-core/src/airflow/utils/dot_renderer.py b/airflow-core/src/airflow/utils/dot_renderer.py index f54ea23935637..fc1685b68b61f 100644 --- a/airflow-core/src/airflow/utils/dot_renderer.py +++ b/airflow-core/src/airflow/utils/dot_renderer.py @@ -134,7 +134,7 @@ def _draw_nodes( node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None ) -> None: """Draw the node and its children on the given parent_graph recursively.""" - if isinstance(node, BaseOperator | MappedOperator): + if isinstance(node, (BaseOperator, MappedOperator)): _draw_task(node, parent_graph, states_by_task_id) else: if not isinstance(node, TaskGroup): diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py index aa111cf07733b..00d865ff9fcf5 100644 --- a/airflow-core/src/airflow/utils/helpers.py +++ b/airflow-core/src/airflow/utils/helpers.py @@ -300,7 +300,7 @@ def is_empty(x): for k, v in val.items(): if is_empty(v): continue - if isinstance(v, list | dict): + if isinstance(v, (list, dict)): new_val = prune_dict(v, mode=mode) if not is_empty(new_val): new_dict[k] = new_val @@ -312,7 +312,7 @@ def is_empty(x): for v in val: if is_empty(v): continue - if isinstance(v, list | dict): + if isinstance(v, (list, dict)): new_val = prune_dict(v, mode=mode) if not is_empty(new_val): new_list.append(new_val) diff --git a/airflow-core/src/airflow/utils/log/colored_log.py b/airflow-core/src/airflow/utils/log/colored_log.py index fa3c9d52c2609..bd9763ce4553a 100644 --- a/airflow-core/src/airflow/utils/log/colored_log.py +++ b/airflow-core/src/airflow/utils/log/colored_log.py @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs): @staticmethod def _color_arg(arg: Any) -> str | float | int: - if isinstance(arg, int | float): + if isinstance(arg, (int, float)): # In case of %d or %f formatting return arg return BOLD_ON + str(arg) + BOLD_OFF @@ -69,7 +69,7 @@ def _count_number_of_arguments_in_message(record: LogRecord) -> int: return len(matches) if matches else 0 def _color_record_args(self, record: LogRecord) -> LogRecord: - if isinstance(record.args, tuple | list): + if isinstance(record.args, (tuple, list)): record.args = tuple(self._color_arg(arg) for arg in record.args) elif isinstance(record.args, dict): if self._count_number_of_arguments_in_message(record) > 1: diff --git a/airflow-core/src/airflow/utils/setup_teardown.py b/airflow-core/src/airflow/utils/setup_teardown.py index d2273c0c3f705..b62e3ab298b40 100644 --- a/airflow-core/src/airflow/utils/setup_teardown.py +++ b/airflow-core/src/airflow/utils/setup_teardown.py @@ -114,7 +114,7 @@ def set_dependency( new_task: AbstractOperator | list[AbstractOperator], upstream=True, ): - if isinstance(new_task, list | tuple): + if isinstance(new_task, (list, tuple)): for task in new_task: cls._set_dependency(task, receiving_task, upstream) else: diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py index a75edafebb567..d08307db14bf5 100644 --- a/airflow-core/src/airflow/utils/sqlalchemy.py +++ b/airflow-core/src/airflow/utils/sqlalchemy.py @@ -165,13 +165,13 @@ def sanitize_for_serialization(obj: V1Pod): """ if obj is None: return None - if isinstance(obj, float | bool | bytes | str | int): + if isinstance(obj, (float, 
bool, bytes, str, int)): return obj if isinstance(obj, list): return [sanitize_for_serialization(sub_obj) for sub_obj in obj] if isinstance(obj, tuple): return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj) - if isinstance(obj, datetime.datetime | datetime.date): + if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() if isinstance(obj, dict): diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py index 2a818c13ac4dd..ae29a5abc6b25 100644 --- a/airflow-core/tests/unit/always/test_project_structure.py +++ b/airflow-core/tests/unit/always/test_project_structure.py @@ -267,7 +267,7 @@ def get_imports_from_file(filepath: str): doc_node = ast.parse(content, filepath) import_names: set[str] = set() for current_node in ast.walk(doc_node): - if not isinstance(current_node, ast.Import | ast.ImportFrom): + if not isinstance(current_node, (ast.Import, ast.ImportFrom)): continue for alias in current_node.names: name = alias.name diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py index 34874c1fcf13f..b1270b7365efc 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py @@ -1095,7 +1095,7 @@ def test_bulk_create_entity_serialization( response = test_client.patch("/variables", json=actions) assert response.status_code == 200 - if isinstance(entity_value, dict | list): + if isinstance(entity_value, (dict, list)): retrieved_value_deserialized = Variable.get(entity_key, deserialize_json=True) assert retrieved_value_deserialized == entity_value retrieved_value_raw_string = Variable.get(entity_key, deserialize_json=False) diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index e78c40d3797d4..1a3b55f0a892e 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -404,7 +404,7 @@ def collect_dags(dag_folder=None): "providers/*/*/tests/system/*/*/", ] else: - if isinstance(dag_folder, list | tuple): + if isinstance(dag_folder, (list, tuple)): patterns = dag_folder else: patterns = [dag_folder] @@ -723,7 +723,7 @@ def validate_deserialized_task( from airflow.sdk.definitions.mappedoperator import MappedOperator assert not isinstance(task, SerializedBaseOperator) - assert isinstance(task, BaseOperator | MappedOperator) + assert isinstance(task, (BaseOperator, MappedOperator)) # Every task should have a task_group property -- even if it's the DAG's root task group assert serialized_task.task_group diff --git a/dev/breeze/src/airflow_breeze/utils/parallel.py b/dev/breeze/src/airflow_breeze/utils/parallel.py index 68dd0b30967a4..1d6af2d174865 100644 --- a/dev/breeze/src/airflow_breeze/utils/parallel.py +++ b/dev/breeze/src/airflow_breeze/utils/parallel.py @@ -217,7 +217,7 @@ def bytes2human(n): def get_printable_value(key: str, value: Any) -> str: if key == "percent": return f"{value} %" - if isinstance(value, int | float): + if isinstance(value, (int, float)): return bytes2human(value) return str(value) diff --git a/devel-common/src/sphinx_exts/operators_and_hooks_ref.py b/devel-common/src/sphinx_exts/operators_and_hooks_ref.py index 54d4fe86215ca..496a45db09ba5 100644 --- 
a/devel-common/src/sphinx_exts/operators_and_hooks_ref.py +++ b/devel-common/src/sphinx_exts/operators_and_hooks_ref.py @@ -287,7 +287,7 @@ def analyze_decorators(node, _file_path, object_type, _class_name=None): if isinstance(child, ast.ClassDef): analyze_decorators(child, file_path, object_type="class") deprecations.extend(_iter_module_for_deprecations(child, file_path, class_name=child.name)) - elif isinstance(child, ast.FunctionDef | ast.AsyncFunctionDef): + elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): analyze_decorators( child, file_path, _class_name=class_name, object_type="method" if class_name else "function" ) diff --git a/devel-common/src/sphinx_exts/providers_extensions.py b/devel-common/src/sphinx_exts/providers_extensions.py index 78a0a40df820a..0ea9148893400 100644 --- a/devel-common/src/sphinx_exts/providers_extensions.py +++ b/devel-common/src/sphinx_exts/providers_extensions.py @@ -140,7 +140,7 @@ def get_import_mappings(tree) -> dict[str, str]: """ imports = {} for node in ast.walk(tree): - if isinstance(node, ast.Import | ast.ImportFrom): + if isinstance(node, (ast.Import, ast.ImportFrom)): for alias in node.names: module_prefix = f"{node.module}." if hasattr(node, "module") and node.module else "" imports[alias.asname or alias.name] = f"{module_prefix}{alias.name}" diff --git a/devel-common/src/sphinx_exts/removemarktransform.py b/devel-common/src/sphinx_exts/removemarktransform.py index 9561baa9ff43c..bb65d026bce93 100644 --- a/devel-common/src/sphinx_exts/removemarktransform.py +++ b/devel-common/src/sphinx_exts/removemarktransform.py @@ -61,7 +61,7 @@ def is_pycode(node: nodes.literal_block) -> bool: if language == "guess": try: lexer = guess_lexer(node.rawsource) - return isinstance(lexer, PythonLexer | Python3Lexer) + return isinstance(lexer, (PythonLexer, Python3Lexer)) except Exception: pass diff --git a/devel-common/src/sphinx_exts/substitution_extensions.py b/devel-common/src/sphinx_exts/substitution_extensions.py index 5fe68e1d55106..faa9501ffee77 100644 --- a/devel-common/src/sphinx_exts/substitution_extensions.py +++ b/devel-common/src/sphinx_exts/substitution_extensions.py @@ -60,7 +60,7 @@ class SubstitutionCodeBlockTransform(SphinxTransform): def apply(self, **kwargs: Any) -> None: def condition(node): - return isinstance(node, nodes.literal_block | nodes.literal) + return isinstance(node, (nodes.literal_block, nodes.literal)) for node in self.document.traverse(condition): if _SUBSTITUTION_OPTION_NAME not in node: diff --git a/devel-common/src/tests_common/_internals/forbidden_warnings.py b/devel-common/src/tests_common/_internals/forbidden_warnings.py index 77e08c9320f49..6e231160e70b9 100644 --- a/devel-common/src/tests_common/_internals/forbidden_warnings.py +++ b/devel-common/src/tests_common/_internals/forbidden_warnings.py @@ -34,7 +34,7 @@ class ForbiddenWarningsPlugin: def __init__(self, config: pytest.Config, forbidden_warnings: tuple[str, ...]): # Set by a pytest_configure hook in conftest deprecations_ignore = config.inicfg["airflow_deprecations_ignore"] - if isinstance(deprecations_ignore, str | os.PathLike): + if isinstance(deprecations_ignore, (str, os.PathLike)): self.deprecations_ignore = [deprecations_ignore] else: self.deprecations_ignore = deprecations_ignore diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py index b9bb25766be44..73b6b33e66992 100644 --- 
a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py @@ -306,8 +306,8 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: """ if ( vals is None - or not isinstance(vals, tuple | list) - or not all(isinstance(val, str | int | float) for val in vals) + or not isinstance(vals, (tuple, list)) + or not all(isinstance(val, (str, int, float)) for val in vals) ): raise ValueError("List of strings expected") return True @@ -322,7 +322,7 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool: if conf: if not isinstance(conf, dict): raise ValueError("'conf' argument must be a dict") - if not all(isinstance(v, str | int) and v != "" for v in conf.values()): + if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()): raise ValueError("'conf' values must be either strings or ints") return True diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py index a432534fab3c3..81cd7a0cb170f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py @@ -652,7 +652,7 @@ async def get_files_async( response = paginator.paginate(**params) async for page in response: if "Contents" in page: - keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), int | float)) + keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), (int, float))) return keys async def _list_keys_async( diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py index 275df5b69a907..00315fbddc370 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py @@ -29,7 +29,7 @@ def _get_message_attribute(o): return {"DataType": "Binary", "BinaryValue": o} if isinstance(o, str): return {"DataType": "String", "StringValue": o} - if isinstance(o, int | float): + if isinstance(o, (int, float)): return {"DataType": "Number", "StringValue": str(o)} if hasattr(o, "__iter__"): return {"DataType": "String.Array", "StringValue": json.dumps(o)} diff --git a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index 2eab684130c42..b5f6aaa12662b 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -55,7 +55,7 @@ def json_serialize_legacy(value: Any) -> str | None: :param value: the object to serialize :return: string representation of `value` if it is an instance of datetime or `None` otherwise """ - if isinstance(value, date | datetime): + if isinstance(value, (date, datetime)): return value.isoformat() return None diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py index 1757614101a19..f486084f34de0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py @@ -536,7 +536,7 @@ def execute(self, context: Context): "Either keys or at least one of prefix, from_datetime, to_datetime should be set." 
) - if isinstance(self.keys, list | str) and not self.keys: + if isinstance(self.keys, (list, str)) and not self.keys: return # handle case where dates are strings, specifically when sent as template fields and macros. if isinstance(self.to_datetime, str): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py index 79e6b966bc056..ef2ca7b4cb343 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -187,7 +187,7 @@ def __post_init__(self, conn: Connection | AwsConnectionWrapper | _ConnectionMet return if TYPE_CHECKING: - assert isinstance(conn, Connection | _ConnectionMetadata) + assert isinstance(conn, (Connection, _ConnectionMetadata)) # Assign attributes from AWS Connection self.conn_id = conn.conn_id diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index 68912c30fe9d3..2c9f4bb7de609 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -439,8 +439,8 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: """ if ( vals is None - or not isinstance(vals, tuple | list) - or not all(isinstance(val, str | int | float) for val in vals) + or not isinstance(vals, (tuple, list)) + or not all(isinstance(val, (str, int, float)) for val in vals) ): raise ValueError("List of strings expected") return True @@ -456,7 +456,7 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool: if conf: if not isinstance(conf, dict): raise ValueError("'conf' argument must be a dict") - if not all(isinstance(v, str | int) and v != "" for v in conf.values()): + if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()): raise ValueError("'conf' values must be either strings or ints") return True @@ -825,8 +825,8 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: """ if ( vals is None - or not isinstance(vals, tuple | list) - or not all(isinstance(val, str | int | float) for val in vals) + or not isinstance(vals, (tuple, list)) + or not all(isinstance(val, (str, int, float)) for val in vals) ): raise ValueError("List of strings expected") return True @@ -842,6 +842,6 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool: if conf: if not isinstance(conf, dict): raise ValueError("'conf' argument must be a dict") - if not all(isinstance(v, str | int) and v != "" for v in conf.values()): + if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()): raise ValueError("'conf' values must be either strings or ints") return True diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py index cba505f9988a8..2829c7012d32b 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py @@ -766,7 +766,7 @@ def _get_error_code(exception: BaseException) -> str: @staticmethod def _retryable_error(exception: BaseException) -> bool: if isinstance(exception, requests_exceptions.RequestException): - if isinstance(exception, requests_exceptions.ConnectionError | 
requests_exceptions.Timeout) or ( + if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or ( exception.response is not None and ( exception.response.status_code >= 500 @@ -783,7 +783,7 @@ def _retryable_error(exception: BaseException) -> bool: if exception.status >= 500 or exception.status == 429: return True - if isinstance(exception, ClientConnectorError | TimeoutError): + if isinstance(exception, (ClientConnectorError, TimeoutError)): return True return False diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py index 86ec7ca385d05..49f9ed874e6bc 100644 --- a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py +++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py @@ -205,11 +205,11 @@ def _generate_partition_query( if isinstance(partition_value, list): output_list.append(f"""{partition_col} in {tuple(partition_value)}""") self.log.debug("List formatting for partitions: %s", output_list) - if isinstance(partition_value, int | float | complex): + if isinstance(partition_value, (int, float, complex)): output_list.append( f"""{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}""" ) - if isinstance(partition_value, str | datetime): + if isinstance(partition_value, (str, datetime)): output_list.append( f"""{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}""" ) diff --git a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py index 8159bce1aff32..e19bccf290a36 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py @@ -34,12 +34,12 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis to string type because databricks does not understand 'True' or 'False' values. """ normalise = normalise_json_content - if isinstance(content, str | bool): + if isinstance(content, (str, bool)): return content - if isinstance(content, int | float): + if isinstance(content, (int, float)): # Databricks can tolerate either numeric or string types in the API backend. 
return str(content) - if isinstance(content, list | tuple): + if isinstance(content, (list, tuple)): return [normalise(e, f"{json_path}[{i}]") for i, e in enumerate(content)] if isinstance(content, dict): return {k: normalise(v, f"{json_path}[{k}]") for k, v in content.items()} diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py index e9c37d926dbcb..4abd415480bcc 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py @@ -117,7 +117,7 @@ class DbtCloudJobRunStatus(Enum): @classmethod def check_is_valid(cls, statuses: int | Sequence[int] | set[int]): """Validate input statuses are a known value.""" - if isinstance(statuses, Sequence | set): + if isinstance(statuses, (Sequence, set)): for status in statuses: cls(status) else: diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py index 221a45fdb3b94..610b03f96e199 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py @@ -130,7 +130,7 @@ def __iter__(self) -> Iterator[Hit]: def __getitem__(self, key): """Retrieve a specific hit or a slice of hits from the Elasticsearch response.""" - if isinstance(key, slice | int): + if isinstance(key, (slice, int)): return self.hits[key] return super().__getitem__(key) diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py index abd3606b1e056..50b883e0f02f3 100644 --- a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py +++ b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py @@ -177,7 +177,7 @@ def _base64_auth_header(auth_value): and returns a base64-encoded string to be used as an HTTP authorization header. 
""" - if isinstance(auth_value, list | tuple): + if isinstance(auth_value, (list, tuple)): auth_value = base64.b64encode(to_bytes(":".join(auth_value))) return to_str(auth_value) @@ -189,11 +189,11 @@ def _escape(value): """ # make sequences into comma-separated strings - if isinstance(value, list | tuple): + if isinstance(value, (list, tuple)): value = ",".join(value) # dates and datetimes into isoformat - elif isinstance(value, date | datetime): + elif isinstance(value, (date, datetime)): value = value.isoformat() # make bools into true/false strings diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py index 64054d1460802..5aede4a0465e7 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py @@ -57,7 +57,7 @@ def serialize_cluster_obj(self, cluster_obj: resources.Cluster) -> dict: """Serialize Cluster dataclass to dict.""" def __encode_value(value: Any) -> Any: - if isinstance(value, list | Repeated): + if isinstance(value, (list, Repeated)): return [__encode_value(nested_value) for nested_value in value] if isinstance(value, ScalarMapContainer): return {key: __encode_value(nested_value) for key, nested_value in dict(value).items()} diff --git a/providers/google/src/airflow/providers/google/cloud/links/base.py b/providers/google/src/airflow/providers/google/cloud/links/base.py index d543b495ab10c..ea6d8850c95f5 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/base.py +++ b/providers/google/src/airflow/providers/google/cloud/links/base.py @@ -107,7 +107,7 @@ def get_link( ti_key: TaskInstanceKey, ) -> str: if TYPE_CHECKING: - assert isinstance(operator, GoogleCloudBaseOperator | BaseSensorOperator) + assert isinstance(operator, (GoogleCloudBaseOperator, BaseSensorOperator)) conf = self.get_config(operator, ti_key) if not conf: diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery_dts.py b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery_dts.py index d7572a05a0fb0..c89763f9b6c49 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery_dts.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery_dts.py @@ -112,7 +112,7 @@ def __init__( self.location = location def _normalize_state_list(self, states) -> set[TransferState]: - states = {states} if isinstance(states, str | TransferState | int) else states + states = {states} if isinstance(states, (str, TransferState, int)) else states result = set() for state in states: if isinstance(state, str): diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index 7c40c8fae5d47..ce56036de6214 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -259,7 +259,7 @@ def generate_data_dict(self, names: Iterable[str], values: Any) -> dict[str, Any def convert_value(self, value: Any | None) -> Any | None: """Convert value to BQ type.""" - if not value or isinstance(value, str | int | float | bool | dict): + if not value or isinstance(value, (str, int, float, bool, dict)): return value if isinstance(value, bytes): return b64encode(value).decode("ascii") @@ -267,13 +267,13 @@ 
def convert_value(self, value: Any | None) -> Any | None: if self.encode_uuid: return b64encode(value.bytes).decode("ascii") return str(value) - if isinstance(value, datetime | Date): + if isinstance(value, (datetime, Date)): return str(value) if isinstance(value, Decimal): return float(value) if isinstance(value, Time): return str(value).split(".")[0] - if isinstance(value, list | SortedSet): + if isinstance(value, (list, SortedSet)): return self.convert_array_types(value) if hasattr(value, "_fields"): return self.convert_user_type(value) diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py index 17fd87e9284c5..8f861aaee690a 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py @@ -114,7 +114,7 @@ def convert_type(cls, value, schema_type, **kwargs): """ if isinstance(value, decimal.Decimal): return float(value) - if isinstance(value, datetime.date | datetime.time): + if isinstance(value, (datetime.date, datetime.time)): return value.isoformat() return value diff --git a/providers/google/src/airflow/providers/google/common/hooks/base_google.py b/providers/google/src/airflow/providers/google/common/hooks/base_google.py index 524acb030c8f3..75ca159906855 100644 --- a/providers/google/src/airflow/providers/google/common/hooks/base_google.py +++ b/providers/google/src/airflow/providers/google/common/hooks/base_google.py @@ -93,7 +93,7 @@ def is_soft_quota_exception(exception: Exception): if isinstance(exception, Forbidden): return any(reason in error.details() for reason in INVALID_REASONS for error in exception.errors) - if isinstance(exception, ResourceExhausted | TooManyRequests): + if isinstance(exception, (ResourceExhausted, TooManyRequests)): return any(key in error.details() for key in INVALID_KEYS for error in exception.errors) return False diff --git a/providers/google/src/airflow/providers/google/suite/transfers/sql_to_sheets.py b/providers/google/src/airflow/providers/google/suite/transfers/sql_to_sheets.py index e791e3adb880c..ed9243a0c900c 100644 --- a/providers/google/src/airflow/providers/google/suite/transfers/sql_to_sheets.py +++ b/providers/google/src/airflow/providers/google/suite/transfers/sql_to_sheets.py @@ -88,7 +88,7 @@ def _data_prep(self, data): for row in data: item_list = [] for item in row: - if isinstance(item, datetime.date | datetime.datetime): + if isinstance(item, (datetime.date, datetime.datetime)): item = item.isoformat() elif isinstance(item, int): # To exclude int from the number check. pass diff --git a/providers/openai/src/airflow/providers/openai/operators/openai.py b/providers/openai/src/airflow/providers/openai/operators/openai.py index 77bf82b665ef4..297b6b86b6474 100644 --- a/providers/openai/src/airflow/providers/openai/operators/openai.py +++ b/providers/openai/src/airflow/providers/openai/operators/openai.py @@ -75,7 +75,7 @@ def hook(self) -> OpenAIHook: return OpenAIHook(conn_id=self.conn_id) def execute(self, context: Context) -> list[float]: - if not self.input_text or not isinstance(self.input_text, str | list): + if not self.input_text or not isinstance(self.input_text, (str, list)): raise ValueError( "The 'input_text' must be a non-empty string, list of strings, list of integers, or list of lists of integers." 
) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 9ca1ba427ee60..0f5704cb2d040 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -233,7 +233,7 @@ def get_user_provided_run_facets(ti: TaskInstance, ti_state: TaskInstanceState) def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> str: - if isinstance(operator, MappedOperator | SerializedBaseOperator): + if isinstance(operator, (MappedOperator, SerializedBaseOperator)): # as in airflow.api_connexion.schemas.common_schema.ClassReferenceSchema return operator._task_module + "." + operator._task_type # type: ignore op_class = get_operator_class(operator) @@ -250,7 +250,7 @@ def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator) -> bo return True if isinstance(obj, DAG): return is_dag_lineage_enabled(obj) - if isinstance(obj, BaseOperator | MappedOperator): + if isinstance(obj, (BaseOperator, MappedOperator)): return is_task_lineage_enabled(obj) raise TypeError("is_selective_lineage_enabled can only be used on DAG or Operator objects") @@ -328,7 +328,7 @@ def _cast_basic_types(value): return value.isoformat() if isinstance(value, datetime.timedelta): return f"{value.total_seconds()} seconds" - if isinstance(value, set | list | tuple): + if isinstance(value, (set, list, tuple)): return str(list(value)) return value @@ -403,7 +403,7 @@ def serialize_timetable(cls, dag: DAG) -> dict[str, Any]: return serialized def _serialize_ds(ds: BaseDatasetEventInput) -> dict[str, Any]: - if isinstance(ds, DatasetAny | DatasetAll): + if isinstance(ds, (DatasetAny, DatasetAll)): return { "__type": "dataset_all" if isinstance(ds, DatasetAll) else "dataset_any", "objects": [_serialize_ds(child) for child in ds.objects], diff --git a/providers/opensearch/src/airflow/providers/opensearch/log/os_response.py b/providers/opensearch/src/airflow/providers/opensearch/log/os_response.py index 4479e21b9ce45..2827cb0f04547 100644 --- a/providers/opensearch/src/airflow/providers/opensearch/log/os_response.py +++ b/providers/opensearch/src/airflow/providers/opensearch/log/os_response.py @@ -130,7 +130,7 @@ def __iter__(self) -> Iterator[Hit]: def __getitem__(self, key): """Retrieve a specific hit or a slice of hits from the Elasticsearch response.""" - if isinstance(key, slice | int): + if isinstance(key, (slice, int)): return self.hits[key] return super().__getitem__(key) diff --git a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py index 355a7b0d8f978..39b32fd119494 100644 --- a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py +++ b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py @@ -156,14 +156,14 @@ def get_conn(self) -> oracledb.Connection: if thick_mode is True: if self.thick_mode_lib_dir is None: self.thick_mode_lib_dir = conn.extra_dejson.get("thick_mode_lib_dir") - if not isinstance(self.thick_mode_lib_dir, str | type(None)): + if not isinstance(self.thick_mode_lib_dir, (str, type(None))): raise TypeError( f"thick_mode_lib_dir expected str or None, " f"got {type(self.thick_mode_lib_dir).__name__}" ) if self.thick_mode_config_dir is None: self.thick_mode_config_dir = conn.extra_dejson.get("thick_mode_config_dir") - if not isinstance(self.thick_mode_config_dir, str | type(None)): + if not 
isinstance(self.thick_mode_config_dir, (str, type(None))): raise TypeError( f"thick_mode_config_dir expected str or None, " f"got {type(self.thick_mode_config_dir).__name__}" diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index 3c506656c3d6d..e46839f598179 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -228,7 +228,7 @@ def _serialize_cell(cell: object, conn: connection | None = None) -> Any: :param conn: The database connection :return: The cell """ - if isinstance(cell, dict | list): + if isinstance(cell, (dict, list)): cell = Json(cell) return cell diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 0c4719bef93e2..ff55bfa21abae 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -49,7 +49,7 @@ def _try_to_boolean(value: Any): - if isinstance(value, str | type(None)): + if isinstance(value, (str, type(None))): return to_boolean(value) return value diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 998fd4c8d7e55..7618ef16b2927 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -192,7 +192,7 @@ def __init__( self.logical_date = logical_date if logical_date is NOTSET: self.logical_date = NOTSET - elif logical_date is None or isinstance(logical_date, str | datetime.datetime): + elif logical_date is None or isinstance(logical_date, (str, datetime.datetime)): self.logical_date = logical_date else: raise TypeError( diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 32ad0111295bd..9912fe96d70d0 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -73,7 +73,7 @@ class ExternalDagLink(BaseOperatorLink): def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: if TYPE_CHECKING: - assert isinstance(operator, ExternalTaskMarker | ExternalTaskSensor) + assert isinstance(operator, (ExternalTaskMarker, ExternalTaskSensor)) external_dag_id = operator.external_dag_id diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index bd8e31bc72a34..32eb545f47019 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -52,7 +52,7 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] from airflow.models.mappedoperator import MappedOperator # type: ignore[no-redef] - return [n for n in nodes if isinstance(n, BaseOperator | MappedOperator)] + return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))] # This class should only be used in Airflow 3.0 and later. 
diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py b/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py index 2d9b0904e48f9..dae400a638825 100644 --- a/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py +++ b/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py @@ -200,7 +200,7 @@ def _transfer_to_and_execute_bteq_on_remote( and exit_status not in ( bteq_quit_rc - if isinstance(bteq_quit_rc, list | tuple) + if isinstance(bteq_quit_rc, (list, tuple)) else [bteq_quit_rc if bteq_quit_rc is not None else 0] ) ): @@ -298,7 +298,7 @@ def execute_bteq_script_at_local( and process.returncode not in ( bteq_quit_rc - if isinstance(bteq_quit_rc, list | tuple) + if isinstance(bteq_quit_rc, (list, tuple)) else [bteq_quit_rc if bteq_quit_rc is not None else 0] ) ): diff --git a/pyproject.toml b/pyproject.toml index a7dbd581e30fb..ffd562e6a9b00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -650,6 +650,7 @@ ignore = [ "COM819", "E501", # Formatted code may exceed the line length, leading to line-too-long (E501) errors. "ASYNC110", # TODO: Use `anyio.Event` instead of awaiting `anyio.sleep` in a `while` loop + "UP038", ] unfixable = [ # PT022 replace empty `yield` to empty `return`. Might be fixed with a combination of PLR1711 diff --git a/scripts/ci/pre_commit/chart_schema.py b/scripts/ci/pre_commit/chart_schema.py index 5257fcca54c49..de720573f7ef3 100755 --- a/scripts/ci/pre_commit/chart_schema.py +++ b/scripts/ci/pre_commit/chart_schema.py @@ -58,7 +58,7 @@ def walk(value, path="$"): if isinstance(value, dict): for k, v in value.items(): yield from walk(v, path + f"[{k!r}]") - elif isinstance(value, list | set | tuple): + elif isinstance(value, (list, set, tuple)): for no, v in enumerate(value): yield from walk(v, path + f"[{no}]") diff --git a/scripts/ci/pre_commit/check_deprecations.py b/scripts/ci/pre_commit/check_deprecations.py index f20d2c55e0459..c1fb6354c9a91 100755 --- a/scripts/ci/pre_commit/check_deprecations.py +++ b/scripts/ci/pre_commit/check_deprecations.py @@ -138,7 +138,7 @@ def built_import(import_clause: ast.Import) -> list[str]: def found_compatible_decorators(mod: ast.Module) -> tuple[str, ...]: result = [] for node in mod.body: - if not isinstance(node, ast.ImportFrom | ast.Import): + if not isinstance(node, (ast.ImportFrom, ast.Import)): continue result.extend(built_import_from(node) if isinstance(node, ast.ImportFrom) else built_import(node)) return tuple(sorted(set(result))) @@ -193,7 +193,7 @@ def check_decorators(mod: ast.Module, file: str, file_group: str) -> int: category_value_ast = category_keyword.value warns_types = allowed_warnings[file_group] - if isinstance(category_value_ast, ast.Name | ast.Attribute): + if isinstance(category_value_ast, (ast.Name, ast.Attribute)): category_value = resolve_name(category_value_ast) if not any(cv.endswith(category_value) for cv in warns_types): errors += 1 diff --git a/scripts/in_container/run_migration_reference.py b/scripts/in_container/run_migration_reference.py index bc92e7766ae15..bdf7bf3b7667b 100755 --- a/scripts/in_container/run_migration_reference.py +++ b/scripts/in_container/run_migration_reference.py @@ -58,7 +58,7 @@ def wrap_backticks(val): def _wrap_backticks(x): return f"``{x}``" - return ",\n".join(map(_wrap_backticks, val)) if isinstance(val, tuple | list) else _wrap_backticks(val) + return ",\n".join(map(_wrap_backticks, val)) if isinstance(val, (tuple, list)) else _wrap_backticks(val) def update_doc(file, data, app): diff --git 
a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index fb77b079a3386..39e4560b47da7 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -424,7 +424,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = ) if isinstance(kwargs, Sequence): for item in kwargs: - if not isinstance(item, XComArg | Mapping): + if not isinstance(item, (XComArg, Mapping)): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") elif not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") diff --git a/task-sdk/src/airflow/sdk/bases/sensor.py b/task-sdk/src/airflow/sdk/bases/sensor.py index a912476322c15..a39de2a151873 100644 --- a/task-sdk/src/airflow/sdk/bases/sensor.py +++ b/task-sdk/src/airflow/sdk/bases/sensor.py @@ -143,7 +143,7 @@ def __init__( def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta: if isinstance(poke_interval, timedelta): return poke_interval - if isinstance(poke_interval, int | float) and poke_interval >= 0: + if isinstance(poke_interval, (int, float)) and poke_interval >= 0: return timedelta(seconds=poke_interval) raise AirflowException( "Operator arg `poke_interval` must be timedelta object or a non-negative number" @@ -153,7 +153,7 @@ def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta: def _coerce_timeout(timeout: float | timedelta) -> timedelta: if isinstance(timeout, timedelta): return timeout - if isinstance(timeout, int | float) and timeout >= 0: + if isinstance(timeout, (int, float)) and timeout >= 0: return timedelta(seconds=timeout) raise AirflowException("Operator arg `timeout` must be timedelta object or a non-negative number") @@ -161,14 +161,14 @@ def _coerce_timeout(timeout: float | timedelta) -> timedelta: def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None: if max_wait is None or isinstance(max_wait, timedelta): return max_wait - if isinstance(max_wait, int | float) and max_wait >= 0: + if isinstance(max_wait, (int, float)) and max_wait >= 0: return timedelta(seconds=max_wait) raise AirflowException("Operator arg `max_wait` must be timedelta object or a non-negative number") def _validate_input_values(self) -> None: - if not isinstance(self.poke_interval, int | float) or self.poke_interval < 0: + if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: raise AirflowException("The poke_interval must be a non-negative number") - if not isinstance(self.timeout, int | float) or self.timeout < 0: + if not isinstance(self.timeout, (int, float)) or self.timeout < 0: raise AirflowException("The timeout must be a non-negative number") if self.mode not in self.valid_modes: raise AirflowException( diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index 4180c4edaa4db..ec2fefa0a08b4 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -426,7 +426,7 @@ def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: for key, child in _walk_group(dag.task_group): if key == self.node_id: continue - if not isinstance(child, MappedOperator | MappedTaskGroup): + if not isinstance(child, (MappedOperator, MappedTaskGroup)): continue if self.node_id in child.upstream_task_ids: yield child diff --git 
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
index 00b12e24b399d..b1c0c6ee5f979 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
@@ -62,21 +62,21 @@ def __str__(self) -> str:
 def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
     from airflow.sdk.definitions.xcom_arg import XComArg

-    return isinstance(v, MappedArgument | XComArg | Mapping | Sequence) and not isinstance(v, str)
+    return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str)


 # To replace tedious isinstance() checks.
 def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
     from airflow.sdk.definitions.xcom_arg import XComArg

-    return not isinstance(v, MappedArgument | XComArg)
+    return not isinstance(v, (MappedArgument, XComArg))


 # To replace tedious isinstance() checks.
 def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
     from airflow.sdk.definitions.xcom_arg import XComArg

-    return isinstance(v, MappedArgument | XComArg)
+    return isinstance(v, (MappedArgument, XComArg))


 @attrs.define(kw_only=True)
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
index 177111af541c0..21fa4ede5b1c9 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -168,7 +168,7 @@ def _set_relatives(
             task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
             relatives = task_object.leaves if upstream else task_object.roots
         for task in relatives:
-            if not isinstance(task, BaseOperator | MappedOperator):
+            if not isinstance(task, (BaseOperator, MappedOperator)):
                 raise TypeError(
                     f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                 )
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 5291a38651486..9e728d6812844 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -122,7 +122,7 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime
         return OnceTimetable()
     if interval == "@continuous":
         return ContinuousTimetable()
-    if isinstance(interval, timedelta | relativedelta):
+    if isinstance(interval, (timedelta, relativedelta)):
         if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"):
             return DeltaDataIntervalTimetable(interval)
         return DeltaTriggerTimetable(interval)
@@ -807,7 +807,7 @@ def partial_subset(
         direct_upstreams: list[Operator] = []
         if include_direct_upstream:
             for t in itertools.chain(matched_tasks, also_include):
-                upstream = (u for u in t.upstream_list if isinstance(u, BaseOperator | MappedOperator))
+                upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
                 direct_upstreams.extend(upstream)

         # Make sure to not recursively deepcopy the dag or task_group while copying the task.
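Note that the `expandinput.py` hunks above only rewrite the runtime checks: the `|` unions inside the `TypeGuard[...]` annotations stay, because annotations are read by type checkers rather than passed to `isinstance()`. A minimal sketch of that split (the class names are stand-ins for illustration, not the real SDK types):

    from __future__ import annotations

    from typing import Any, TypeGuard

    class MappedArgument: ...
    class XComArg: ...

    def needs_run_time_resolution(v: Any) -> TypeGuard[MappedArgument | XComArg]:
        # The union survives in the annotation for the type checker; the
        # runtime check itself uses the portable tuple form.
        return isinstance(v, (MappedArgument, XComArg))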
diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py b/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py
index 2c1e2c70e0675..c77b65bf48e34 100644
--- a/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py
+++ b/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py
@@ -150,7 +150,7 @@ def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
     def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
         if isinstance(kwargs, Sequence):
             for item in kwargs:
-                if not isinstance(item, XComArg | Mapping):
+                if not isinstance(item, (XComArg, Mapping)):
                     raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
         elif not isinstance(kwargs, XComArg):
             raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
diff --git a/task-sdk/src/airflow/sdk/definitions/edges.py b/task-sdk/src/airflow/sdk/definitions/edges.py
index 4a52b9df32c4e..39fafc4b932c8 100644
--- a/task-sdk/src/airflow/sdk/definitions/edges.py
+++ b/task-sdk/src/airflow/sdk/definitions/edges.py
@@ -75,7 +75,7 @@ def _save_nodes(
         from airflow.sdk.definitions.xcom_arg import XComArg

         for node in self._make_list(nodes):
-            if isinstance(node, TaskGroup | XComArg | DAGNode):
+            if isinstance(node, (TaskGroup, XComArg, DAGNode)):
                 stream.append(node)
             else:
                 raise TypeError(
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index 9f2cc18604354..7fddf969137fb 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -140,9 +140,9 @@ def is_mappable_value(value: Any) -> TypeGuard[Collection]:
     :meta private:
     """
-    if not isinstance(value, Sequence | dict):
+    if not isinstance(value, (Sequence, dict)):
         return False
-    if isinstance(value, bytearray | bytes | str):
+    if isinstance(value, (bytearray, bytes, str)):
         return False
     return True
@@ -192,7 +192,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool =
         if isinstance(kwargs, Sequence):
             for item in kwargs:
-                if not isinstance(item, XComArg | Mapping):
+                if not isinstance(item, (XComArg, Mapping)):
                     raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
         elif not isinstance(kwargs, XComArg):
             raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
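The `is_mappable_value` hunk above shows why the second check exists at all: `str`, `bytes`, and `bytearray` are themselves `Sequence`s and would otherwise be treated as expandable. A simplified, self-contained sketch of the same logic:

    from collections.abc import Sequence

    def is_mappable_value(value):
        # Dicts and sequences can be mapped over, but string-like sequences
        # (of characters/bytes) must be excluded explicitly.
        if not isinstance(value, (Sequence, dict)):
            return False
        if isinstance(value, (bytearray, bytes, str)):
            return False
        return True

    assert is_mappable_value([1, 2, 3]) and is_mappable_value({"a": 1})
    assert not is_mappable_value("abc")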
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index 7a5ed0468e739..162a63db5c884 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -104,7 +104,7 @@ def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
         """
         if isinstance(arg, ResolveMixin):
             yield from arg.iter_references()
-        elif isinstance(arg, tuple | set | list):
+        elif isinstance(arg, (tuple, set, list)):
             for elem in arg:
                 yield from XComArg.iter_xcom_references(elem)
         elif isinstance(arg, dict):
@@ -429,7 +429,7 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg:
     def resolve(self, context: Mapping[str, Any]) -> Any:
         value = self.arg.resolve(context)
-        if not isinstance(value, Sequence | dict):
+        if not isinstance(value, (Sequence, dict)):
             raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
         return _MapResult(value, self.callables)
@@ -494,7 +494,7 @@ def iter_references(self) -> Iterator[tuple[Operator, str]]:
     def resolve(self, context: Mapping[str, Any]) -> Any:
         values = [arg.resolve(context) for arg in self.args]
         for value in values:
-            if not isinstance(value, Sequence | dict):
+            if not isinstance(value, (Sequence, dict)):
                 raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
         return _ZipResult(values, fillvalue=self.fillvalue)
@@ -557,7 +557,7 @@ def concat(self, *others: XComArg) -> ConcatXComArg:
     def resolve(self, context: Mapping[str, Any]) -> Any:
         values = [arg.resolve(context) for arg in self.args]
         for value in values:
-            if not isinstance(value, Sequence | dict):
+            if not isinstance(value, (Sequence, dict)):
                 raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}")
         return _ConcatResult(values)
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index e95b9b173dc90..c76994995ebab 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -529,7 +529,7 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve
         msg: ToSupervisor
         if isinstance(key, int):  # Support index access; it's easier for trivial cases.
             obj = self._inlets[key]
-            if not isinstance(obj, Asset | AssetAlias | AssetRef):
+            if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
                 raise IndexError(key)
         else:
             obj = key
diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py
index 560a493d3bc3f..b36eca9cd0809 100644
--- a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py
+++ b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py
@@ -232,7 +232,7 @@ def _redact_all(self, item: Redactable, depth: int, max_depth: int = MAX_RECURSI
             return {
                 dict_key: self._redact_all(subval, depth + 1, max_depth) for dict_key, subval in item.items()
             }
-        if isinstance(item, tuple | set):
+        if isinstance(item, (tuple, set)):
             # Turn set in to tuple!
             return tuple(self._redact_all(subval, depth + 1, max_depth) for subval in item)
         if isinstance(item, list):
@@ -270,7 +270,7 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int
                 # the structure.
                 return self.replacer.sub("***", str(item))
             return item
-        if isinstance(item, tuple | set):
+        if isinstance(item, (tuple, set)):
             # Turn set in to tuple!
             return tuple(
                 self._redact(subval, name=None, depth=(depth + 1), max_depth=max_depth)
                 for subval in item
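The `secrets_masker.py` hunks above sit inside a recursive container walk that rebuilds dicts, lists, and tuples, deliberately turning sets into tuples along the way. A much-simplified sketch of that shape (the real methods live on `SecretsMasker` and redact leaves selectively rather than wholesale):

    def redact_all(item, replacement="***", depth=0, max_depth=5):
        # Recurse through containers, replacing every leaf; sets come back
        # as tuples, matching the "Turn set in to tuple!" comments above.
        if depth > max_depth:
            return replacement
        if isinstance(item, dict):
            return {k: redact_all(v, replacement, depth + 1, max_depth) for k, v in item.items()}
        if isinstance(item, (tuple, set)):
            return tuple(redact_all(v, replacement, depth + 1, max_depth) for v in item)
        if isinstance(item, list):
            return [redact_all(v, replacement, depth + 1, max_depth) for v in item]
        return replacement

    assert redact_all({"password": "hunter2"}) == {"password": "***"}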
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 24127d65a2763..6c6e597f65e5c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -340,8 +340,8 @@ def xcom_pull(
         if run_id is None:
             run_id = self.run_id

-        single_task_requested = isinstance(task_ids, str | type(None))
-        single_map_index_requested = isinstance(map_indexes, int | type(None))
+        single_task_requested = isinstance(task_ids, (str, type(None)))
+        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

         if task_ids is None:
             # default to the current task if not provided
@@ -618,7 +618,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
         )
         exit(1)

-    if not isinstance(task, BaseOperator | MappedOperator):
+    if not isinstance(task, (BaseOperator, MappedOperator)):
         raise TypeError(
             f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
         )
diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py
index f2660faa2d206..80c71f41a8fe8 100644
--- a/task-sdk/tests/conftest.py
+++ b/task-sdk/tests/conftest.py
@@ -119,7 +119,7 @@ def captured_logs(request):
     # We need to replace remove the last processor (the one that turns JSON into text, as we want the
     # event dict for tests)
     proc = processors.pop()
-    assert isinstance(proc, structlog.dev.ConsoleRenderer | structlog.processors.JSONRenderer), (
+    assert isinstance(proc, (structlog.dev.ConsoleRenderer, structlog.processors.JSONRenderer)), (
         "Pre-condition"
     )
     try:
diff --git a/task-sdk/tests/task_sdk/definitions/conftest.py b/task-sdk/tests/task_sdk/definitions/conftest.py
index 7ad358487ba63..3f89f34b4d2da 100644
--- a/task-sdk/tests/task_sdk/definitions/conftest.py
+++ b/task-sdk/tests/task_sdk/definitions/conftest.py
@@ -42,7 +42,7 @@ def run(dag: DAG, task_id: str, map_index: int):
         for call in mock_supervisor_comms.send.mock_calls:
             msg = call.kwargs.get("msg") or call.args[0]
-            if isinstance(msg, TaskState | SucceedTask):
+            if isinstance(msg, (TaskState, SucceedTask)):
                 return msg.state
         raise RuntimeError("Unable to find call to TaskState")
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py
index 34992c73f8673..7bd45edd8f99a 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset.py
@@ -244,7 +244,7 @@ def assets_equal(a1: BaseAsset, a2: BaseAsset) -> bool:
     if isinstance(a1, Asset) and isinstance(a2, Asset):
         return a1.uri == a2.uri

-    if isinstance(a1, AssetAny | AssetAll) and isinstance(a2, AssetAny | AssetAll):
+    if isinstance(a1, (AssetAny, AssetAll)) and isinstance(a2, (AssetAny, AssetAll)):
         if len(a1.objects) != len(a2.objects):
             return False
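One idiom worth calling out from the `task_runner.py` hunk above: putting `type(None)` inside the tuple lets a single `isinstance()` call treat `None` (use the current task) and a lone task-id string as the same "one task requested" case. A tiny sketch:

    def is_single_task_request(task_ids):
        # None and a single string both mean one task; a list/tuple of ids
        # means a multi-task request.
        return isinstance(task_ids, (str, type(None)))

    assert is_single_task_request(None)
    assert is_single_task_request("extract")
    assert not is_single_task_request(["extract", "load"])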