From 72d268022f77cc26bb8f3aac031fcbce17ca7607 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Wed, 21 Jan 2026 15:51:17 +0100 Subject: [PATCH 01/27] initial ty integration --- .circleci/config.yml | 1 + Makefile | 6 ++++++ docker/quality.dockerfile | 2 +- pyproject.toml | 20 +++++++++++++++++++ setup.py | 3 ++- src/transformers/dependency_versions_table.py | 1 + 6 files changed, 31 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 7875cdc368f5..4c6ff48cd482 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -156,6 +156,7 @@ jobs: path: ~/transformers/installed.txt - run: ruff check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py - run: ruff format --check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py + - run: ty check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py - run: python utils/custom_init_isort.py --check_only - run: python utils/sort_auto_mappings.py --check_only diff --git a/Makefile b/Makefile index 8b3b4dc2acba..e151113c4660 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,11 @@ export PYTHONPATH = src check_dirs := examples tests src utils scripts benchmark benchmark_v2 exclude_folders := "" +# Helper to find all Python files in directories (ty doesn't recursively scan directories) +define get_py_files +$(shell find $(1) -name "*.py" -type f 2>/dev/null) +endef + # this runs all linting/formatting scripts, most notably ruff style: @@ -20,6 +25,7 @@ style: check-repo: ruff check $(check_dirs) setup.py conftest.py ruff format --check $(check_dirs) setup.py conftest.py + -ty check $(call get_py_files,src/transformers) setup.py conftest.py -python utils/custom_init_isort.py --check_only -python utils/sort_auto_mappings.py --check_only -python -c "from transformers import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 
🚨'; exit 1) diff --git a/docker/quality.dockerfile b/docker/quality.dockerfile index 6455a27d642b..97987b0d098d 100644 --- a/docker/quality.dockerfile +++ b/docker/quality.dockerfile @@ -5,5 +5,5 @@ USER root RUN apt-get update && apt-get install -y time git ENV UV_PYTHON=/usr/local/bin/python RUN pip install uv -RUN uv pip install --no-cache-dir -U pip setuptools GitPython "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[ruff]" urllib3 +RUN uv pip install --no-cache-dir -U pip setuptools GitPython "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[quality]" urllib3 RUN apt-get install -y jq curl && apt-get clean && rm -rf /var/lib/apt/lists/* diff --git a/pyproject.toml b/pyproject.toml index 2705851dd49a..c138b905cd21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,3 +92,23 @@ env = [ # Note: 'D:' means default value from laptop or CI won't be overwritten "D:HF_HUB_DOWNLOAD_TIMEOUT=60", ] + +[tool.ty] +# ty type checker configuration +# Using default settings for comprehensive type checking + +[tool.ty.rules] +# Disable specific rules that produce false positives or are too strict for this codebase +invalid-method-override = "ignore" # Parameter name differences are acceptable (e.g., x vs input, new_embeddings vs value) +not-subscriptable = "ignore" # False positives on tensor slicing (e.g., self.position_ids[:, :seq_length]) +no-matching-overload = "ignore" # False positives on torch.zeros and similar functions accepting Size/tuple +unsupported-operator = "ignore" # False positives on tuple concatenation with += when properly initialized +unresolved-import = "ignore" # Optional dependencies (mlx, torch_npu, habana_frameworks, etc.) 
checked at runtime +call-non-callable = "ignore" # Mixin pattern issues where classes are used as both types and callables +unresolved-reference = "ignore" # Forward references with noqa: F821 that ty doesn't respect +invalid-argument-type = "ignore" # Complex type narrowing and union type issues +not-iterable = "ignore" # Complex async/Future type patterns +invalid-return-type = "ignore" # Return type mismatches that would require refactoring +deprecated = "ignore" # Deprecation warnings from dependencies +invalid-assignment = "ignore" # Low-level assignments that are runtime-safe +unused-ignore-comment = "ignore" # Ignore comments that became unnecessary after adding broader per-file-ignores diff --git a/setup.py b/setup.py index 6996881fe95e..90b08fe12f2d 100644 --- a/setup.py +++ b/setup.py @@ -124,6 +124,7 @@ "rjieba", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff==0.14.10", + "ty==0.0.12", # `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls # `evaluate.load("sacrebleu")`. This metric is used in the examples that we use to test the `Trainer` with, in the # `Trainer` tests (see references to `run_translation.py`). 
@@ -180,7 +181,7 @@ def deps_list(*pkgs): extras["audio"] += deps_list("kenlm") extras["video"] = deps_list("av") extras["timm"] = deps_list("timm") -extras["quality"] = deps_list("datasets", "ruff", "GitPython", "urllib3", "libcst", "rich") +extras["quality"] = deps_list("datasets", "ruff", "GitPython", "urllib3", "libcst", "rich", "ty") extras["kernels"] = deps_list("kernels") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["tiktoken"] = deps_list("tiktoken", "blobfile") diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 9f0315bbbe15..9c121dabafb1 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -56,6 +56,7 @@ "rjieba": "rjieba", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff": "ruff==0.14.10", + "ty": "ty==0.0.12", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacremoses": "sacremoses", "safetensors": "safetensors>=0.4.3", From 30456177e73ed78e56385cefbbe764cd1d4310d2 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Wed, 21 Jan 2026 17:46:51 +0100 Subject: [PATCH 02/27] narrow ty check to utils and add more ignores --- Makefile | 2 +- src/transformers/trainer_utils.py | 6 +++--- src/transformers/utils/chat_template_utils.py | 14 +++++++------- src/transformers/utils/generic.py | 4 ++-- src/transformers/utils/hub.py | 2 +- src/transformers/utils/import_utils.py | 10 +++++++++- src/transformers/utils/logging.py | 4 ++-- src/transformers/utils/type_validators.py | 4 ++-- 8 files changed, 27 insertions(+), 19 deletions(-) diff --git a/Makefile b/Makefile index e151113c4660..ba78e2a4d461 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ style: check-repo: ruff check $(check_dirs) setup.py conftest.py ruff format --check $(check_dirs) setup.py conftest.py - -ty check $(call get_py_files,src/transformers) setup.py conftest.py + ty check $(call get_py_files,src/transformers/utils) --force-exclude 
--exclude '**/*_pb2*.py' -python utils/custom_init_isort.py --check_only -python utils/sort_auto_mappings.py --check_only -python -c "from transformers import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 46582e4069c8..1d470bd5aa16 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -26,10 +26,10 @@ import shutil import threading import time -from collections.abc import Callable +from collections.abc import Callable, Sized from functools import partial from pathlib import Path -from typing import Any, NamedTuple +from typing import Any, NamedTuple, TypeGuard import numpy as np @@ -890,7 +890,7 @@ def stop_and_update_metrics(self, metrics=None): self.update_metrics(stage, metrics) -def has_length(dataset): +def has_length(dataset: Any) -> TypeGuard[Sized]: """ Checks if the dataset implements __len__() and it doesn't raise an error """ diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 1c639d205542..7c91cb78165d 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -197,7 +197,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> dict: if param_name == implicit_arg_name: continue if param.annotation == inspect.Parameter.empty: - raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") + raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") # type: ignore[attr-defined] if param.default == inspect.Parameter.empty: required.append(param_name) @@ -358,7 +358,7 @@ def get_json_schema(func: Callable) -> dict: doc = inspect.getdoc(func) if not doc: raise DocstringParsingException( - f"Cannot generate JSON schema for {func.__name__} because it has no docstring!" 
+ f"Cannot generate JSON schema for {func.__name__} because it has no docstring!" # type: ignore[attr-defined] ) doc = doc.strip() main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) @@ -370,7 +370,7 @@ def get_json_schema(func: Callable) -> dict: for arg, schema in json_schema["properties"].items(): if arg not in param_descriptions: raise DocstringParsingException( - f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" + f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" # type: ignore[attr-defined] ) desc = param_descriptions[arg] enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE) @@ -379,7 +379,7 @@ def get_json_schema(func: Callable) -> dict: desc = enum_choices.string[: enum_choices.start()].strip() schema["description"] = desc - output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} + output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} # type: ignore[attr-defined] if return_dict is not None: output["return"] = return_dict return {"type": "function", "function": output} @@ -421,10 +421,10 @@ def __init__(self, environment: ImmutableSandboxedEnvironment): self._rendered_blocks = None self._generation_indices = None - def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: # type: ignore[name-defined] lineno = next(parser.stream).lineno body = parser.parse_statements(["name:endgeneration"], drop_needle=True) - return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) + return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) # type: ignore[attr-defined] @jinja2.pass_eval_context def _generation_support(self, context: 
jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: @@ -488,7 +488,7 @@ def render_jinja_template( **kwargs, ) -> str: if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template): - logger.warning_once( + logger.warning_once( # type: ignore[attr-defined] "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword." ) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index d6eec7433a79..f94e17e59e41 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -462,8 +462,8 @@ def _model_output_flatten(output: ModelOutput) -> tuple[list[Any], _torch_pytree def _model_output_unflatten( values: Iterable[Any], - context: _torch_pytree.Context, - output_type=None, + context: "_torch_pytree.Context", + output_type: type[ModelOutput] | None = None, ) -> ModelOutput: return output_type(**dict(zip(context, values))) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 127ae9bdc595..5f2d6f14e90f 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -295,7 +295,7 @@ def cached_files( _raise_exceptions_for_connection_errors: bool = True, _commit_hash: str | None = None, **deprecated_kwargs, -) -> str | None: +) -> list[str] | None: """ Tries to locate several files in a local folder and repo, downloads and cache them if necessary. 
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1e527203e917..dac9d7d8934a 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -31,7 +31,7 @@ from functools import lru_cache from itertools import chain from types import ModuleType -from typing import Any +from typing import Any, Literal, overload import packaging.version from packaging import version @@ -45,6 +45,14 @@ PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions() +@overload +def _is_package_available(pkg_name: str, return_version: Literal[True]) -> tuple[bool, str]: ... + + +@overload +def _is_package_available(pkg_name: str, return_version: Literal[False] = False) -> bool: ... + + def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: """Check if `pkg_name` exist, and optionally try to get its version""" spec = importlib.util.find_spec(pkg_name) diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index c9fc19f26dd7..17093880a3fa 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -327,7 +327,7 @@ def warning_once(self, *args, **kwargs): self.warning(*args, **kwargs) -logging.Logger.warning_once = warning_once +logging.Logger.warning_once = warning_once # type: ignore[attr-defined] @functools.lru_cache(None) @@ -342,7 +342,7 @@ def info_once(self, *args, **kwargs): self.info(*args, **kwargs) -logging.Logger.info_once = info_once +logging.Logger.info_once = info_once # type: ignore[attr-defined] class EmptyTqdm: diff --git a/src/transformers/utils/type_validators.py b/src/transformers/utils/type_validators.py index 8775150ece22..2128f510004b 100644 --- a/src/transformers/utils/type_validators.py +++ b/src/transformers/utils/type_validators.py @@ -98,14 +98,14 @@ def check_dict_keys(d: dict) -> bool: if isinstance(value, Sequence) and isinstance(value[0], Sequence) and 
isinstance(value[0][0], dict): for sublist in value: for item in sublist: - if not check_dict_keys(item): + if not check_dict_keys(item): # type: ignore[arg-type] raise ValueError( f"Invalid keys found in video metadata. Valid keys: {valid_keys} got: {list(item.keys())}" ) elif isinstance(value, Sequence) and isinstance(value[0], dict): for item in value: - if not check_dict_keys(item): + if not check_dict_keys(item): # type: ignore[arg-type] raise ValueError( f"Invalid keys found in video metadata. Valid keys: {valid_keys} got: {list(item.keys())}" ) From f92bb1c187c4daafe072862adf5d6b5b5a50fd40 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Wed, 21 Jan 2026 18:04:52 +0100 Subject: [PATCH 03/27] circleci does not use the makefile, fix ty changes there --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 4c6ff48cd482..ab63a3823c2f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -156,7 +156,7 @@ jobs: path: ~/transformers/installed.txt - run: ruff check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py - run: ruff format --check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py - - run: ty check examples tests src utils scripts benchmark benchmark_v2 setup.py conftest.py + - run: ty check src/transformers/utils/*.py --force-exclude --exclude '**/*_pb2*.py' - run: python utils/custom_init_isort.py --check_only - run: python utils/sort_auto_mappings.py --check_only From 5a7ca7e0f91a09fa20c4dc12541537b3d50c9764 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 08:53:21 +0100 Subject: [PATCH 04/27] remove unecessary ignores --- src/transformers/utils/type_validators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/type_validators.py b/src/transformers/utils/type_validators.py index 2128f510004b..960d55d9dbfa 100644 --- 
a/src/transformers/utils/type_validators.py +++ b/src/transformers/utils/type_validators.py @@ -92,20 +92,20 @@ def video_metadata_validator(value: VideoMetadataType | None = None): valid_keys = ["total_num_frames", "fps", "width", "height", "duration", "video_backend", "frames_indices"] - def check_dict_keys(d: dict) -> bool: + def check_dict_keys(d: dict[str, Any]) -> bool: return all(key in valid_keys for key in d.keys()) if isinstance(value, Sequence) and isinstance(value[0], Sequence) and isinstance(value[0][0], dict): for sublist in value: for item in sublist: - if not check_dict_keys(item): # type: ignore[arg-type] + if not check_dict_keys(item): raise ValueError( f"Invalid keys found in video metadata. Valid keys: {valid_keys} got: {list(item.keys())}" ) elif isinstance(value, Sequence) and isinstance(value[0], dict): for item in value: - if not check_dict_keys(item): # type: ignore[arg-type] + if not check_dict_keys(item): raise ValueError( f"Invalid keys found in video metadata. 
Valid keys: {valid_keys} got: {list(item.keys())}" ) From d6595523fa939f547eb3427843e316f21202ff79 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 09:13:54 +0100 Subject: [PATCH 05/27] removed a couple of ignores --- .../utils/attention_visualizer.py | 5 +- src/transformers/utils/chat_template_utils.py | 7 +++ src/transformers/utils/doc.py | 7 ++- src/transformers/utils/hub.py | 8 ++- src/transformers/utils/import_utils.py | 57 +++++++++++-------- src/transformers/utils/loading_report.py | 2 +- src/transformers/utils/logging.py | 9 +-- src/transformers/utils/notebook.py | 29 ++++++---- src/transformers/utils/peft_utils.py | 1 + src/transformers/utils/quantization_config.py | 11 +++- src/transformers/utils/type_validators.py | 4 +- 11 files changed, 89 insertions(+), 51 deletions(-) diff --git a/src/transformers/utils/attention_visualizer.py b/src/transformers/utils/attention_visualizer.py index 3063393492b1..c1253d36541d 100644 --- a/src/transformers/utils/attention_visualizer.py +++ b/src/transformers/utils/attention_visualizer.py @@ -222,7 +222,10 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""): ) if causal_mask is not None: - attention_mask = ~causal_mask.bool() + if hasattr(causal_mask, "bool"): + attention_mask = ~causal_mask.bool() + else: + attention_mask = ~causal_mask else: attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length) top_bottom_border = "##" * ( diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 7c91cb78165d..3231851a12c0 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -41,6 +41,10 @@ if is_jinja_available(): import jinja2 + import jinja2.exceptions + import jinja2.ext + import jinja2.nodes + import jinja2.runtime from jinja2.ext import Extension from jinja2.sandbox import ImmutableSandboxedEnvironment else: @@ -409,6 +413,7 @@ def 
_compile_jinja_template(chat_template): raise ImportError( "apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`." ) + assert jinja2 is not None class AssistantTracker(Extension): # This extension is used to track the indices of assistant-generated tokens in the rendered chat @@ -431,6 +436,7 @@ def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2. rv = caller() if self.is_active(): # Only track generation indices if the tracker is active + assert self._rendered_blocks is not None and self._generation_indices is not None start_index = len("".join(self._rendered_blocks)) end_index = start_index + len(rv) self._generation_indices.append((start_index, end_index)) @@ -458,6 +464,7 @@ def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[ ) def raise_exception(message): + assert jinja2 is not None raise jinja2.exceptions.TemplateError(message) def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index eb648a205ccc..a2113fcf297c 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -1091,6 +1091,7 @@ def copy_func(f): """Returns a copy of a function f.""" # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard) g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__) - g = functools.update_wrapper(g, f) - g.__kwdefaults__ = f.__kwdefaults__ - return g + wrapped = functools.update_wrapper(g, f) + if hasattr(f, "__kwdefaults__"): + setattr(wrapped, "__kwdefaults__", f.__kwdefaults__) + return wrapped diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 5f2d6f14e90f..fea444606279 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -168,7 +168,8 @@ def define_sagemaker_information(): sagemaker_params = 
json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}")) runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params - account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None + training_job_arn = os.getenv("TRAINING_JOB_ARN") + account_id = training_job_arn.split(":")[4] if training_job_arn is not None else None sagemaker_object = { "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None), @@ -773,7 +774,10 @@ def push_to_hub( with tempfile.TemporaryDirectory() as tmp_dir: # Save all files. - self.save_pretrained(tmp_dir, max_shard_size=max_shard_size) + if hasattr(self, "save_pretrained"): + self.save_pretrained(tmp_dir, max_shard_size=max_shard_size) + else: + raise AttributeError("The object must have a save_pretrained method to use push_to_hub") # Update model card model_card.save(os.path.join(tmp_dir, "README.md")) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index dac9d7d8934a..33fe454f0f53 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -238,8 +238,10 @@ def is_torch_npu_available(check_device=False) -> bool: if check_device: try: # Will raise a RuntimeError if no NPU is found - _ = torch.npu.device_count() - return torch.npu.is_available() + if hasattr(torch, "npu"): + _ = torch.npu.device_count() + return torch.npu.is_available() + return False except RuntimeError: return False return hasattr(torch, "npu") and torch.npu.is_available() @@ -285,7 +287,7 @@ def is_torch_mlu_available() -> bool: pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK") try: os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1) - available = torch.mlu.is_available() + available = torch.mlu.is_available() if hasattr(torch, "mlu") else False finally: if pytorch_cndev_based_mlu_check_previous_value: os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = 
pytorch_cndev_based_mlu_check_previous_value @@ -312,8 +314,10 @@ def is_torch_musa_available(check_device=False) -> bool: if check_device: try: # Will raise a RuntimeError if no MUSA is found - _ = torch.musa.device_count() - return torch.musa.is_available() + if hasattr(torch, "musa"): + _ = torch.musa.device_count() + return torch.musa.is_available() + return False except RuntimeError: return False return hasattr(torch, "musa") and torch.musa.is_available() @@ -451,12 +455,12 @@ def is_torch_bf16_gpu_available() -> bool: if is_torch_hpu_available(): return True if is_torch_npu_available(): - return torch.npu.is_bf16_supported() + return torch.npu.is_bf16_supported() if hasattr(torch, "npu") else False if is_torch_mps_available(): # Note: Emulated in software by Metal using fp32 for hardware without native support (like M1/M2) return torch.backends.mps.is_macos_or_newer(14, 0) if is_torch_musa_available(): - return torch.musa.is_bf16_supported() + return torch.musa.is_bf16_supported() if hasattr(torch, "musa") else False return False @@ -516,9 +520,10 @@ def is_torch_tf32_available() -> bool: import torch if is_torch_musa_available(): - device_info = torch.musa.get_device_properties(torch.musa.current_device()) - if f"{device_info.major}{device_info.minor}" >= "22": - return True + if hasattr(torch, "musa"): + device_info = torch.musa.get_device_properties(torch.musa.current_device()) + if f"{device_info.major}{device_info.minor}" >= "22": + return True return False if not torch.cuda.is_available() or torch.version.cuda is None: return False @@ -541,10 +546,12 @@ def enable_tf32(enable: bool) -> None: pytorch_version = version.parse(get_torch_version()) if pytorch_version >= version.parse("2.9.0"): precision_mode = "tf32" if enable else "ieee" - torch.backends.fp32_precision = precision_mode + if hasattr(torch.backends, "fp32_precision"): + torch.backends.fp32_precision = precision_mode else: if is_torch_musa_available(): - torch.backends.mudnn.allow_tf32 = 
enable + if hasattr(torch.backends, "mudnn"): + torch.backends.mudnn.allow_tf32 = enable else: torch.backends.cuda.matmul.allow_tf32 = enable torch.backends.cudnn.allow_tf32 = enable @@ -2012,7 +2019,7 @@ def __init__( # Needed for autocompletion in an IDE def __dir__(self): - result = super().__dir__() + result = list(super().__dir__()) # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. for attr in self.__all__: @@ -2266,10 +2273,12 @@ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType: name = "transformers" location = os.path.join(path, file) spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path]) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - module = sys.modules[name] - return module + if spec is not None and spec.loader is not None: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module = sys.modules[name] + return module + raise ImportError(f"Could not load module {name} from {location}") class VersionComparison(Enum): @@ -2283,13 +2292,13 @@ class VersionComparison(Enum): @staticmethod def from_string(version_string: str) -> "VersionComparison": string_to_operator = { - "=": VersionComparison.EQUAL.value, - "==": VersionComparison.EQUAL.value, - "!=": VersionComparison.NOT_EQUAL.value, - ">": VersionComparison.GREATER_THAN.value, - "<": VersionComparison.LESS_THAN.value, - ">=": VersionComparison.GREATER_THAN_OR_EQUAL.value, - "<=": VersionComparison.LESS_THAN_OR_EQUAL.value, + "=": VersionComparison.EQUAL, + "==": VersionComparison.EQUAL, + "!=": VersionComparison.NOT_EQUAL, + ">": VersionComparison.GREATER_THAN, + "<": VersionComparison.LESS_THAN, + ">=": VersionComparison.GREATER_THAN_OR_EQUAL, + "<=": VersionComparison.LESS_THAN_OR_EQUAL, } return 
string_to_operator[version_string] diff --git a/src/transformers/utils/loading_report.py b/src/transformers/utils/loading_report.py index c613e3bc1acf..0e5810abd8c1 100644 --- a/src/transformers/utils/loading_report.py +++ b/src/transformers/utils/loading_report.py @@ -50,7 +50,7 @@ def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]: mapping = {k: k for k in mapping} not_mapping = True - bucket: dict[tuple[str, Any], list[set[int]]] = defaultdict(list) + bucket: dict[str, list[set[int] | Any]] = defaultdict(list) for key, val in mapping.items(): digs = _DIGIT_RX.findall(key) patt = _pattern_of(key) diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index 17093880a3fa..93729ba4be2a 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -99,7 +99,8 @@ def _configure_library_root_logger() -> None: formatter = logging.Formatter("[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s") _default_handler.setFormatter(formatter) - is_ci = os.getenv("CI") is not None and os.getenv("CI").upper() in {"1", "ON", "YES", "TRUE"} + ci = os.getenv("CI") + is_ci = ci is not None and ci.upper() in {"1", "ON", "YES", "TRUE"} library_root_logger.propagate = is_ci @@ -312,7 +313,7 @@ def warning_advice(self, *args, **kwargs): self.warning(*args, **kwargs) -logging.Logger.warning_advice = warning_advice +logging.Logger.warning_advice = warning_advice # type: ignore[unresolved-attribute] @functools.lru_cache(None) @@ -327,7 +328,7 @@ def warning_once(self, *args, **kwargs): self.warning(*args, **kwargs) -logging.Logger.warning_once = warning_once # type: ignore[attr-defined] +logging.Logger.warning_once = warning_once # type: ignore[unresolved-attribute] @functools.lru_cache(None) @@ -342,7 +343,7 @@ def info_once(self, *args, **kwargs): self.info(*args, **kwargs) -logging.Logger.info_once = info_once # type: ignore[attr-defined] +logging.Logger.info_once = info_once # type: 
ignore[unresolved-attribute] class EmptyTqdm: diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index 1660d546ed1e..dcd2a79a9bdd 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -307,11 +307,12 @@ def on_train_begin(self, args, state, control, **kwargs): def on_step_end(self, args, state, control, **kwargs): epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}" - self.training_tracker.update( - state.global_step + 1, - comment=f"Epoch {epoch}/{state.num_train_epochs}", - force_update=self._force_next_update, - ) + if self.training_tracker is not None: + self.training_tracker.update( + state.global_step + 1, + comment=f"Epoch {epoch}/{state.num_train_epochs}", + force_update=self._force_next_update, + ) self._force_next_update = False def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): @@ -337,7 +338,8 @@ def on_log(self, args, state, control, logs=None, **kwargs): values = {"Training Loss": logs["loss"]} # First column is necessarily Step sine we're not in epoch eval strategy values["Step"] = state.global_step - self.training_tracker.write_line(values) + if self.training_tracker is not None: + self.training_tracker.write_line(values) def on_evaluate(self, args, state, control, metrics=None, **kwargs): if self.training_tracker is not None: @@ -351,6 +353,8 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs): values["Epoch"] = int(state.epoch) else: values["Step"] = state.global_step + if metrics is None: + metrics = {} metric_key_prefix = "eval" for k in metrics: if k.endswith("_loss"): @@ -374,9 +378,10 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs): self._force_next_update = True def on_train_end(self, args, state, control, **kwargs): - self.training_tracker.update( - state.global_step, - comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", - force_update=True, - ) - 
self.training_tracker = None + if self.training_tracker is not None: + self.training_tracker.update( + state.global_step, + comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", + force_update=True, + ) + self.training_tracker = None diff --git a/src/transformers/utils/peft_utils.py b/src/transformers/utils/peft_utils.py index 99062bf6502f..a1ec093b4e67 100644 --- a/src/transformers/utils/peft_utils.py +++ b/src/transformers/utils/peft_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import importlib.metadata import os from packaging import version diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 68330c30f037..d78ca7d721f1 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -172,7 +172,10 @@ def to_json_string(self, use_diff: bool = True) -> str: `str`: String containing all the attributes that make up this configuration instance in JSON format. 
""" if use_diff is True: - config_dict = self.to_diff_dict() + if hasattr(self, "to_diff_dict"): + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() else: config_dict = self.to_dict() return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" @@ -1255,7 +1258,11 @@ def is_quantized(self): def is_quantization_compressed(self): from compressed_tensors.quantization import QuantizationStatus - return self.is_quantized and self.quantization_config.quantization_status == QuantizationStatus.COMPRESSED + return ( + self.is_quantized + and self.quantization_config is not None + and self.quantization_config.quantization_status == QuantizationStatus.COMPRESSED + ) @property def is_sparsification_compressed(self): diff --git a/src/transformers/utils/type_validators.py b/src/transformers/utils/type_validators.py index 960d55d9dbfa..2600998277a6 100644 --- a/src/transformers/utils/type_validators.py +++ b/src/transformers/utils/type_validators.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Union +from typing import Any, Union, cast from ..tokenization_utils_base import PaddingStrategy, TruncationStrategy from ..video_utils import VideoMetadataType @@ -107,7 +107,7 @@ def check_dict_keys(d: dict[str, Any]) -> bool: for item in value: if not check_dict_keys(item): raise ValueError( - f"Invalid keys found in video metadata. Valid keys: {valid_keys} got: {list(item.keys())}" + f"Invalid keys found in video metadata. 
Valid keys: {valid_keys} got: {list(cast(dict, item).keys())}" ) elif isinstance(value, dict): From a83625b4e8cbc940ed32523adba5e7957a59e6cc Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 09:20:35 +0100 Subject: [PATCH 06/27] causal_mask can be a Tensor, BlockMask or None, lets be explicit --- src/transformers/utils/attention_visualizer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/utils/attention_visualizer.py b/src/transformers/utils/attention_visualizer.py index c1253d36541d..a8967ac9b3fa 100644 --- a/src/transformers/utils/attention_visualizer.py +++ b/src/transformers/utils/attention_visualizer.py @@ -221,13 +221,14 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""): past_key_values=None, ) - if causal_mask is not None: - if hasattr(causal_mask, "bool"): - attention_mask = ~causal_mask.bool() - else: - attention_mask = ~causal_mask - else: + if causal_mask is None: + # attention_mask must be a tensor here attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length) + elif isinstance(causal_mask, torch.Tensor): + attention_mask = ~causal_mask.to(dtype=torch.bool) + else: + attention_mask = ~causal_mask + top_bottom_border = "##" * ( len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4 ) # Box width adjusted to text length From 152217d8c535bb201e02297ce50adb0c46d1bb11 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 09:28:34 +0100 Subject: [PATCH 07/27] simplify to_json_string --- src/transformers/utils/quantization_config.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index d78ca7d721f1..b9ae09d53e3f 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -159,6 +159,12 @@ def __iter__(self): 
def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" + def to_diff_dict(self) -> dict[str, Any]: + """ + Default behavior: no diffing implemented for this config. + """ + return self.to_dict() + def to_json_string(self, use_diff: bool = True) -> str: """ Serializes this instance to a JSON string. @@ -171,13 +177,7 @@ def to_json_string(self, use_diff: bool = True) -> str: Returns: `str`: String containing all the attributes that make up this configuration instance in JSON format. """ - if use_diff is True: - if hasattr(self, "to_diff_dict"): - config_dict = self.to_diff_dict() - else: - config_dict = self.to_dict() - else: - config_dict = self.to_dict() + config_dict = self.to_diff_dict() if use_diff else self.to_dict() return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def update(self, **kwargs): From 11d2679094e8df8cd5efb48be7157d2f457f490f Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 09:35:02 +0100 Subject: [PATCH 08/27] make it more readable - we know qc cannot be None here but its best balance with ty --- src/transformers/utils/quantization_config.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index b9ae09d53e3f..a225ef067833 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1258,11 +1258,8 @@ def is_quantized(self): def is_quantization_compressed(self): from compressed_tensors.quantization import QuantizationStatus - return ( - self.is_quantized - and self.quantization_config is not None - and self.quantization_config.quantization_status == QuantizationStatus.COMPRESSED - ) + qc = self.quantization_config + return qc is not None and bool(qc.config_groups) and qc.quantization_status == QuantizationStatus.COMPRESSED @property def is_sparsification_compressed(self): From a6f5fa39d5bd25996b12d2745322db3e170f0c0b Mon 
Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 09:45:16 +0100 Subject: [PATCH 09/27] explicitly assert training_tracker, do not silently ignore --- src/transformers/utils/notebook.py | 98 +++++++++++++++--------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index dcd2a79a9bdd..8b59ad78531f 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -307,12 +307,12 @@ def on_train_begin(self, args, state, control, **kwargs): def on_step_end(self, args, state, control, **kwargs): epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}" - if self.training_tracker is not None: - self.training_tracker.update( - state.global_step + 1, - comment=f"Epoch {epoch}/{state.num_train_epochs}", - force_update=self._force_next_update, - ) + assert self.training_tracker is not None, "on_train_begin must be called before on_step_end" + self.training_tracker.update( + state.global_step + 1, + comment=f"Epoch {epoch}/{state.num_train_epochs}", + force_update=self._force_next_update, + ) self._force_next_update = False def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): @@ -338,50 +338,50 @@ def on_log(self, args, state, control, logs=None, **kwargs): values = {"Training Loss": logs["loss"]} # First column is necessarily Step sine we're not in epoch eval strategy values["Step"] = state.global_step - if self.training_tracker is not None: - self.training_tracker.write_line(values) + assert self.training_tracker is not None, "on_train_begin must be called before on_log" + self.training_tracker.write_line(values) def on_evaluate(self, args, state, control, metrics=None, **kwargs): - if self.training_tracker is not None: - values = {"Training Loss": "No log", "Validation Loss": "No log"} - for log in reversed(state.log_history): - if "loss" in log: - values["Training Loss"] = log["loss"] 
- break - - if self.first_column == "Epoch": - values["Epoch"] = int(state.epoch) - else: - values["Step"] = state.global_step - if metrics is None: - metrics = {} - metric_key_prefix = "eval" - for k in metrics: - if k.endswith("_loss"): - metric_key_prefix = re.sub(r"\_loss$", "", k) - _ = metrics.pop("total_flos", None) - _ = metrics.pop("epoch", None) - _ = metrics.pop(f"{metric_key_prefix}_runtime", None) - _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None) - _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None) - for k, v in metrics.items(): - splits = k.split("_") - name = " ".join([part.capitalize() for part in splits[1:]]) - if name == "Loss": - # Single dataset - name = "Validation Loss" - values[name] = v - self.training_tracker.write_line(values) - self.training_tracker.remove_child() - self.prediction_bar = None - # Evaluation takes a long time so we should force the next update. - self._force_next_update = True + assert self.training_tracker is not None, "on_train_begin must be called before on_evaluate" + values = {"Training Loss": "No log", "Validation Loss": "No log"} + for log in reversed(state.log_history): + if "loss" in log: + values["Training Loss"] = log["loss"] + break + + if self.first_column == "Epoch": + values["Epoch"] = int(state.epoch) + else: + values["Step"] = state.global_step + if metrics is None: + metrics = {} + metric_key_prefix = "eval" + for k in metrics: + if k.endswith("_loss"): + metric_key_prefix = re.sub(r"\_loss$", "", k) + _ = metrics.pop("total_flos", None) + _ = metrics.pop("epoch", None) + _ = metrics.pop(f"{metric_key_prefix}_runtime", None) + _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None) + _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None) + for k, v in metrics.items(): + splits = k.split("_") + name = " ".join([part.capitalize() for part in splits[1:]]) + if name == "Loss": + # Single dataset + name = "Validation Loss" + values[name] = v + 
self.training_tracker.write_line(values) + self.training_tracker.remove_child() + self.prediction_bar = None + # Evaluation takes a long time so we should force the next update. + self._force_next_update = True def on_train_end(self, args, state, control, **kwargs): - if self.training_tracker is not None: - self.training_tracker.update( - state.global_step, - comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", - force_update=True, - ) - self.training_tracker = None + assert self.training_tracker is not None, "on_train_begin must be called before on_train_end" + self.training_tracker.update( + state.global_step, + comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", + force_update=True, + ) + self.training_tracker = None From 2e5df020e3ccd6875c17fcc0aeb4177cde8a7a3c Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 10:31:49 +0100 Subject: [PATCH 10/27] make _is_package_available return type unique, and fully private to its module --- .../integrations/integration_utils.py | 20 ++ .../models/dia/convert_dia_to_hf.py | 4 +- .../models/whisper/convert_openai_to_hf.py | 4 +- src/transformers/utils/import_utils.py | 214 +++++++++--------- 4 files changed, 131 insertions(+), 111 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 87865fb6b94a..7fe42297a5cc 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -35,6 +35,7 @@ import numpy as np import packaging.version +from transformers.utils.import_utils import is_pynvml_available if os.getenv("WANDB_MODE") == "offline": print("[INFO] Running in WANDB offline mode") @@ -56,6 +57,7 @@ if is_torch_available(): import torch + import torch.distributed as dist # comet_ml requires to be imported before any ML frameworks _MIN_COMET_VERSION = "3.43.2" @@ -998,6 +1000,24 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs): 
"total_flos", ] + if is_torch_available() and torch.cuda.is_available(): + device_idx = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(device_idx).total_memory + memory_allocated = torch.cuda.memory_allocated(device_idx) + + gpu_memory_logs = { + f"gpu/{device_idx}/allocated_memory": memory_allocated / (1024**3), # GB + f"gpu/{device_idx}/memory_usage": memory_allocated / total_memory, # ratio + } + if is_pynvml_available(): + power = torch.cuda.power_draw(device_idx) + gpu_memory_logs[f"gpu/{device_idx}/power"] = power / 1000 # Watts + if dist.is_available() and dist.is_initialized(): + gathered_logs = [None] * dist.get_world_size() + dist.all_gather_object(gathered_logs, gpu_memory_logs) + gpu_memory_logs = {k: v for d in gathered_logs for k, v in d.items()} + else: + gpu_memory_logs = {} if not self._initialized: self.setup(args, state, model) if state.is_world_process_zero: diff --git a/src/transformers/models/dia/convert_dia_to_hf.py b/src/transformers/models/dia/convert_dia_to_hf.py index 732e71b54e32..067f176e1404 100644 --- a/src/transformers/models/dia/convert_dia_to_hf.py +++ b/src/transformers/models/dia/convert_dia_to_hf.py @@ -30,7 +30,7 @@ DiaTokenizer, GenerationConfig, ) -from transformers.utils.import_utils import _is_package_available +from transformers.utils.import_utils import is_tiktoken_available # Provide just the list of layer keys you want to fix @@ -180,7 +180,7 @@ def convert_dia_model_to_hf(checkpoint_path, verbose=False): model = convert_dia_model_to_hf(args.checkpoint_path, args.verbose) if args.convert_preprocessor: try: - if not _is_package_available("tiktoken"): + if not is_tiktoken_available(with_blobfile=False): raise ModuleNotFoundError( """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer""" ) diff --git a/src/transformers/models/whisper/convert_openai_to_hf.py b/src/transformers/models/whisper/convert_openai_to_hf.py index b30caab9f261..9a26ddb3a0f2 100755 --- 
a/src/transformers/models/whisper/convert_openai_to_hf.py +++ b/src/transformers/models/whisper/convert_openai_to_hf.py @@ -38,7 +38,7 @@ WhisperTokenizerFast, ) from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode -from transformers.utils.import_utils import _is_package_available +from transformers.utils.import_utils import is_tiktoken_available _MODELS = { @@ -345,7 +345,7 @@ def convert_tiktoken_to_hf( if args.convert_preprocessor: try: - if not _is_package_available("tiktoken"): + if not is_tiktoken_available(with_blobfile=False): raise ModuleNotFoundError( """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer""" ) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 33fe454f0f53..74d61e27ac4f 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -31,7 +31,7 @@ from functools import lru_cache from itertools import chain from types import ModuleType -from typing import Any, Literal, overload +from typing import Any import packaging.version from packaging import version @@ -45,15 +45,7 @@ PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions() -@overload -def _is_package_available(pkg_name: str, return_version: Literal[True]) -> tuple[bool, str]: ... - - -@overload -def _is_package_available(pkg_name: str, return_version: Literal[False] = False) -> bool: ... 
- - -def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: +def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str]: """Check if `pkg_name` exist, and optionally try to get its version""" spec = importlib.util.find_spec(pkg_name) package_exists = spec is not None @@ -79,10 +71,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[ package = importlib.import_module(pkg_name) package_version = getattr(package, "__version__", "N/A") logger.debug(f"Detected {pkg_name} version: {package_version}") + if return_version: return package_exists, package_version else: - return package_exists + return package_exists, None def is_env_variable_true(env_variable: str) -> bool: @@ -229,7 +222,7 @@ def is_torch_mps_available(min_version: str | None = None) -> bool: @lru_cache def is_torch_npu_available(check_device=False) -> bool: "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" - if not is_torch_available() or not _is_package_available("torch_npu"): + if not is_torch_available() or not _is_package_available("torch_npu")[0]: return False import torch @@ -278,7 +271,7 @@ def is_torch_mlu_available() -> bool: Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu uninitialized. 
""" - if not is_torch_available() or not _is_package_available("torch_mlu"): + if not is_torch_available() or not _is_package_available("torch_mlu")[0]: return False import torch @@ -300,7 +293,7 @@ def is_torch_mlu_available() -> bool: @lru_cache def is_torch_musa_available(check_device=False) -> bool: "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment" - if not is_torch_available() or not _is_package_available("torch_musa"): + if not is_torch_available() or not _is_package_available("torch_musa")[0]: return False import torch @@ -331,7 +324,7 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False) -> bool: """ assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." - torch_xla_available = USE_TORCH_XLA in ENV_VARS_TRUE_VALUES and _is_package_available("torch_xla") + torch_xla_available = USE_TORCH_XLA in ENV_VARS_TRUE_VALUES and _is_package_available("torch_xla")[0] if not torch_xla_available: return False @@ -350,8 +343,8 @@ def is_torch_hpu_available() -> bool: "Checks if `torch.hpu` is available and potentially if a HPU is in the environment" if ( not is_torch_available() - or not _is_package_available("habana_frameworks") - or not _is_package_available("habana_frameworks.torch") + or not _is_package_available("habana_frameworks")[0] + or not _is_package_available("habana_frameworks.torch")[0] ): return False @@ -569,7 +562,7 @@ def is_grouped_mm_available() -> bool: @lru_cache def is_kenlm_available() -> bool: - return _is_package_available("kenlm") + return _is_package_available("kenlm")[0] @lru_cache @@ -580,17 +573,17 @@ def is_kernels_available(MIN_VERSION: str = KERNELS_MIN_VERSION) -> bool: @lru_cache def is_cv2_available() -> bool: - return _is_package_available("cv2") + return _is_package_available("cv2")[0] @lru_cache def is_yt_dlp_available() -> bool: - return _is_package_available("yt_dlp") + return _is_package_available("yt_dlp")[0] @lru_cache def 
is_libcst_available() -> bool: - return _is_package_available("libcst") + return _is_package_available("libcst")[0] @lru_cache @@ -607,7 +600,7 @@ def is_triton_available(min_version: str = TRITON_MIN_VERSION) -> bool: @lru_cache def is_hadamard_available() -> bool: - return _is_package_available("fast_hadamard_transform") + return _is_package_available("fast_hadamard_transform")[0] @lru_cache @@ -618,12 +611,12 @@ def is_hqq_available(min_version: str = HQQ_MIN_VERSION) -> bool: @lru_cache def is_pygments_available() -> bool: - return _is_package_available("pygments") + return _is_package_available("pygments")[0] @lru_cache def is_torchvision_available() -> bool: - return _is_package_available("torchvision") + return _is_package_available("torchvision")[0] @lru_cache @@ -633,27 +626,27 @@ def is_torchvision_v2_available() -> bool: @lru_cache def is_galore_torch_available() -> bool: - return _is_package_available("galore_torch") + return _is_package_available("galore_torch")[0] @lru_cache def is_apollo_torch_available() -> bool: - return _is_package_available("apollo_torch") + return _is_package_available("apollo_torch")[0] @lru_cache def is_torch_optimi_available() -> bool: - return _is_package_available("optimi") + return _is_package_available("optimi")[0] @lru_cache def is_lomo_available() -> bool: - return _is_package_available("lomo_optim") + return _is_package_available("lomo_optim")[0] @lru_cache def is_grokadamw_available() -> bool: - return _is_package_available("grokadamw") + return _is_package_available("grokadamw")[0] @lru_cache @@ -664,47 +657,47 @@ def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION) -> bo @lru_cache def is_pyctcdecode_available() -> bool: - return _is_package_available("pyctcdecode") + return _is_package_available("pyctcdecode")[0] @lru_cache def is_librosa_available() -> bool: - return _is_package_available("librosa") + return _is_package_available("librosa")[0] @lru_cache def is_essentia_available() -> bool: - 
return _is_package_available("essentia") + return _is_package_available("essentia")[0] @lru_cache def is_pydantic_available() -> bool: - return _is_package_available("pydantic") + return _is_package_available("pydantic")[0] @lru_cache def is_fastapi_available() -> bool: - return _is_package_available("fastapi") + return _is_package_available("fastapi")[0] @lru_cache def is_uvicorn_available() -> bool: - return _is_package_available("uvicorn") + return _is_package_available("uvicorn")[0] @lru_cache def is_openai_available() -> bool: - return _is_package_available("openai") + return _is_package_available("openai")[0] @lru_cache def is_pretty_midi_available() -> bool: - return _is_package_available("pretty_midi") + return _is_package_available("pretty_midi")[0] @lru_cache def is_mamba_ssm_available() -> bool: - return is_torch_cuda_available() and _is_package_available("mamba_ssm") + return is_torch_cuda_available() and _is_package_available("mamba_ssm")[0] @lru_cache @@ -721,37 +714,37 @@ def is_flash_linear_attention_available(): @lru_cache def is_causal_conv1d_available() -> bool: - return is_torch_cuda_available() and _is_package_available("causal_conv1d") + return is_torch_cuda_available() and _is_package_available("causal_conv1d")[0] @lru_cache def is_xlstm_available() -> bool: - return is_torch_available() and _is_package_available("xlstm") + return is_torch_available() and _is_package_available("xlstm")[0] @lru_cache def is_mambapy_available() -> bool: - return is_torch_available() and _is_package_available("mambapy") + return is_torch_available() and _is_package_available("mambapy")[0] @lru_cache def is_peft_available() -> bool: - return _is_package_available("peft") + return _is_package_available("peft")[0] @lru_cache def is_bs4_available() -> bool: - return _is_package_available("bs4") + return _is_package_available("bs4")[0] @lru_cache def is_coloredlogs_available() -> bool: - return _is_package_available("coloredlogs") + return 
_is_package_available("coloredlogs")[0] @lru_cache def is_onnx_available() -> bool: - return _is_package_available("onnx") + return _is_package_available("onnx")[0] @lru_cache @@ -762,22 +755,22 @@ def is_flute_available() -> bool: @lru_cache def is_g2p_en_available() -> bool: - return _is_package_available("g2p_en") + return _is_package_available("g2p_en")[0] @lru_cache def is_torch_neuroncore_available(check_device=True) -> bool: - return is_torch_xla_available() and _is_package_available("torch_neuronx") + return is_torch_xla_available() and _is_package_available("torch_neuronx")[0] @lru_cache def is_torch_tensorrt_fx_available() -> bool: - return _is_package_available("torch_tensorrt") and _is_package_available("torch_tensorrt.fx") + return _is_package_available("torch_tensorrt")[0] and _is_package_available("torch_tensorrt.fx")[0] @lru_cache def is_datasets_available() -> bool: - return _is_package_available("datasets") + return _is_package_available("datasets")[0] @lru_cache @@ -795,32 +788,32 @@ def is_detectron2_available() -> bool: @lru_cache def is_rjieba_available() -> bool: - return _is_package_available("rjieba") + return _is_package_available("rjieba")[0] @lru_cache def is_psutil_available() -> bool: - return _is_package_available("psutil") + return _is_package_available("psutil")[0] @lru_cache def is_py3nvml_available() -> bool: - return _is_package_available("py3nvml") + return _is_package_available("py3nvml")[0] @lru_cache def is_sacremoses_available() -> bool: - return _is_package_available("sacremoses") + return _is_package_available("sacremoses")[0] @lru_cache def is_apex_available() -> bool: - return _is_package_available("apex") + return _is_package_available("apex")[0] @lru_cache def is_aqlm_available() -> bool: - return _is_package_available("aqlm") + return _is_package_available("aqlm")[0] @lru_cache @@ -831,17 +824,17 @@ def is_vptq_available(min_version: str = VPTQ_MIN_VERSION) -> bool: @lru_cache def is_av_available() -> bool: - return 
_is_package_available("av") + return _is_package_available("av")[0] @lru_cache def is_decord_available() -> bool: - return _is_package_available("decord") + return _is_package_available("decord")[0] @lru_cache def is_torchcodec_available() -> bool: - return _is_package_available("torchcodec") + return _is_package_available("torchcodec")[0] @lru_cache @@ -888,7 +881,7 @@ def is_flash_attn_2_available() -> bool: @lru_cache def is_flash_attn_3_available() -> bool: - return is_torch_cuda_available() and _is_package_available("flash_attn_3") + return is_torch_cuda_available() and _is_package_available("flash_attn_3")[0] @lru_cache @@ -939,32 +932,32 @@ def is_quanto_greater(library_version: str, accept_dev: bool = False) -> bool: @lru_cache def is_torchdistx_available(): - return _is_package_available("torchdistx") + return _is_package_available("torchdistx")[0] @lru_cache def is_faiss_available() -> bool: - return _is_package_available("faiss") + return _is_package_available("faiss")[0] @lru_cache def is_scipy_available() -> bool: - return _is_package_available("scipy") + return _is_package_available("scipy")[0] @lru_cache def is_sklearn_available() -> bool: - return _is_package_available("sklearn") + return _is_package_available("sklearn")[0] @lru_cache def is_sentencepiece_available() -> bool: - return _is_package_available("sentencepiece") + return _is_package_available("sentencepiece")[0] @lru_cache def is_seqio_available() -> bool: - return _is_package_available("seqio") + return _is_package_available("seqio")[0] @lru_cache @@ -975,7 +968,7 @@ def is_gguf_available(min_version: str = GGUF_MIN_VERSION) -> bool: @lru_cache def is_protobuf_available() -> bool: - return _is_package_available("google") and _is_package_available("google.protobuf") + return _is_package_available("google")[0] and _is_package_available("google.protobuf")[0] @lru_cache @@ -985,12 +978,12 @@ def is_fsdp_available(min_version: str = FSDP_MIN_VERSION) -> bool: @lru_cache def 
is_optimum_available() -> bool: - return _is_package_available("optimum") + return _is_package_available("optimum")[0] @lru_cache def is_llm_awq_available() -> bool: - return _is_package_available("awq") + return _is_package_available("awq")[0] @lru_cache @@ -1001,12 +994,12 @@ def is_auto_round_available(min_version: str = AUTOROUND_MIN_VERSION) -> bool: @lru_cache def is_optimum_quanto_available(): - return is_optimum_available() and _is_package_available("optimum.quanto") + return is_optimum_available() and _is_package_available("optimum.quanto")[0] @lru_cache def is_quark_available() -> bool: - return _is_package_available("quark") + return _is_package_available("quark")[0] @lru_cache @@ -1023,7 +1016,7 @@ def is_qutlass_available(): @lru_cache def is_compressed_tensors_available() -> bool: - return _is_package_available("compressed_tensors") + return _is_package_available("compressed_tensors")[0] @lru_cache @@ -1033,87 +1026,87 @@ def is_sinq_available() -> bool: @lru_cache def is_gptqmodel_available() -> bool: - return _is_package_available("gptqmodel") + return _is_package_available("gptqmodel")[0] @lru_cache def is_fbgemm_gpu_available() -> bool: - return _is_package_available("fbgemm_gpu") + return _is_package_available("fbgemm_gpu")[0] @lru_cache def is_levenshtein_available() -> bool: - return _is_package_available("Levenshtein") + return _is_package_available("Levenshtein")[0] @lru_cache def is_optimum_neuron_available() -> bool: - return is_optimum_available() and _is_package_available("optimum.neuron") + return is_optimum_available() and _is_package_available("optimum.neuron")[0] @lru_cache def is_tokenizers_available() -> bool: - return _is_package_available("tokenizers") + return _is_package_available("tokenizers")[0] @lru_cache def is_vision_available() -> bool: - return _is_package_available("PIL") + return _is_package_available("PIL")[0] @lru_cache def is_pytesseract_available() -> bool: - return _is_package_available("pytesseract") + return 
_is_package_available("pytesseract")[0] @lru_cache def is_pytest_available() -> bool: - return _is_package_available("pytest") + return _is_package_available("pytest")[0] @lru_cache def is_pytest_order_available() -> bool: - return is_pytest_available() and _is_package_available("pytest_order") + return is_pytest_available() and _is_package_available("pytest_order")[0] @lru_cache def is_spacy_available() -> bool: - return _is_package_available("spacy") + return _is_package_available("spacy")[0] @lru_cache def is_pytorch_quantization_available() -> bool: - return _is_package_available("pytorch_quantization") + return _is_package_available("pytorch_quantization")[0] @lru_cache def is_pandas_available() -> bool: - return _is_package_available("pandas") + return _is_package_available("pandas")[0] @lru_cache def is_soundfile_available() -> bool: - return _is_package_available("soundfile") + return _is_package_available("soundfile")[0] @lru_cache def is_timm_available() -> bool: - return _is_package_available("timm") + return _is_package_available("timm")[0] @lru_cache def is_natten_available() -> bool: - return _is_package_available("natten") + return _is_package_available("natten")[0] @lru_cache def is_nltk_available() -> bool: - return _is_package_available("nltk") + return _is_package_available("nltk")[0] @lru_cache def is_numba_available() -> bool: - is_available = _is_package_available("numba") + is_available = _is_package_available("numba")[0] if not is_available: return False @@ -1123,7 +1116,7 @@ def is_numba_available() -> bool: @lru_cache def is_torchaudio_available() -> bool: - return _is_package_available("torchaudio") + return _is_package_available("torchaudio")[0] @lru_cache @@ -1140,22 +1133,22 @@ def is_speech_available() -> bool: @lru_cache def is_spqr_available() -> bool: - return _is_package_available("spqr_quant") + return _is_package_available("spqr_quant")[0] @lru_cache def is_phonemizer_available() -> bool: - return 
_is_package_available("phonemizer") + return _is_package_available("phonemizer")[0] @lru_cache def is_uroman_available() -> bool: - return _is_package_available("uroman") + return _is_package_available("uroman")[0] @lru_cache def is_sudachi_available() -> bool: - return _is_package_available("sudachipy") + return _is_package_available("sudachipy")[0] @lru_cache @@ -1166,37 +1159,39 @@ def is_sudachi_projection_available() -> bool: @lru_cache def is_jumanpp_available() -> bool: - return _is_package_available("rhoknp") and shutil.which("jumanpp") is not None + return _is_package_available("rhoknp")[0] and shutil.which("jumanpp") is not None @lru_cache def is_cython_available() -> bool: - return _is_package_available("pyximport") + return _is_package_available("pyximport")[0] @lru_cache def is_jinja_available() -> bool: - return _is_package_available("jinja2") + return _is_package_available("jinja2")[0] @lru_cache def is_jmespath_available() -> bool: - return _is_package_available("jmespath") + return _is_package_available("jmespath")[0] @lru_cache def is_mlx_available() -> bool: - return _is_package_available("mlx") + return _is_package_available("mlx")[0] @lru_cache def is_num2words_available() -> bool: - return _is_package_available("num2words") + return _is_package_available("num2words")[0] @lru_cache -def is_tiktoken_available() -> bool: - return _is_package_available("tiktoken") and _is_package_available("blobfile") +def is_tiktoken_available(with_blobfile: bool = True) -> bool: + if not _is_package_available("tiktoken")[0]: + return False + return with_blobfile and _is_package_available("blobfile")[0] or True @lru_cache @@ -1207,29 +1202,34 @@ def is_liger_kernel_available() -> bool: @lru_cache def is_rich_available() -> bool: - return _is_package_available("rich") + return _is_package_available("rich")[0] @lru_cache def is_matplotlib_available() -> bool: - return _is_package_available("matplotlib") + return _is_package_available("matplotlib")[0] @lru_cache def 
is_mistral_common_available() -> bool: - return _is_package_available("mistral_common") + return _is_package_available("mistral_common")[0] @lru_cache def is_opentelemetry_available() -> bool: try: - return _is_package_available("opentelemetry") and version.parse( + return _is_package_available("opentelemetry")[0] and version.parse( importlib.metadata.version("opentelemetry-api") ) >= version.parse("1.30.0") except Exception as _: return False +@lru_cache +def is_pynvml_available() -> bool: + return _is_package_available("pynvml")[0] + + def check_torch_load_is_safe() -> None: if not is_torch_greater_or_equal("2.6"): raise ValueError( @@ -1444,7 +1444,7 @@ def is_sagemaker_dp_enabled() -> bool: except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. - return _is_package_available("smdistributed") + return _is_package_available("smdistributed")[0] def is_sagemaker_mp_enabled() -> bool: @@ -1468,7 +1468,7 @@ def is_sagemaker_mp_enabled() -> bool: except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. 
- return _is_package_available("smdistributed") + return _is_package_available("smdistributed")[0] def is_training_run_on_sagemaker() -> bool: From eee45dae4a299d9dad4b37eb74af4063952a9353 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 10:38:12 +0100 Subject: [PATCH 11/27] explicit contract at the base class level is nicer than duck typing checks --- src/transformers/utils/hub.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index fea444606279..b3d5c19f2984 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -709,6 +709,10 @@ def _upload_modified_files( revision=revision, ) + def save_pretrained(self, *args, **kwargs): + # explicit contract + raise NotImplementedError(f"{self.__class__.__name__} must implement `save_pretrained` to use `push_to_hub`.") + def push_to_hub( self, repo_id: str, @@ -774,10 +778,7 @@ def push_to_hub( with tempfile.TemporaryDirectory() as tmp_dir: # Save all files. 
- if hasattr(self, "save_pretrained"): - self.save_pretrained(tmp_dir, max_shard_size=max_shard_size) - else: - raise AttributeError("The object must have a save_pretrained method to use push_to_hub") + self.save_pretrained(tmp_dir, max_shard_size=max_shard_size) # Update model card model_card.save(os.path.join(tmp_dir, "README.md")) From 0e34af5d0a495ecce4169002ba51b5e70ba1711b Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 22 Jan 2026 10:47:28 +0100 Subject: [PATCH 12/27] add a comment about monkey patching the logger --- src/transformers/utils/logging.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index 93729ba4be2a..b2a220fd4d16 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -313,6 +313,8 @@ def warning_advice(self, *args, **kwargs): self.warning(*args, **kwargs) +# TODO: ideally we should have a new logger class, e.g. TransformerLogger that adds these new methods +# instead of monkey patching logging.Logger.warning_advice = warning_advice # type: ignore[unresolved-attribute] From 11b998d0900b82a9805887e55d98a445f1918b04 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Tue, 3 Feb 2026 14:37:52 +0100 Subject: [PATCH 13/27] added ignore --- src/transformers/utils/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index f94e17e59e41..98f0e575e7fd 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -878,7 +878,7 @@ def wrapper(self, *args, **kwargs): # Arg-specific handling if arg_name == "use_cache": if getattr(self, "gradient_checkpointing", False) and self.training and arg_value: - logger.warning_once( + logger.warning_once( # type: ignore[attr-defined] "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
) arg_value = False From 4eb870db5ea55d70f3c46690cdc16d2033c53498 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 5 Feb 2026 08:57:09 +0100 Subject: [PATCH 14/27] better one --- src/transformers/utils/doc.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index a2113fcf297c..0521d346d92c 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -1091,7 +1091,6 @@ def copy_func(f): """Returns a copy of a function f.""" # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard) g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__) - wrapped = functools.update_wrapper(g, f) - if hasattr(f, "__kwdefaults__"): - setattr(wrapped, "__kwdefaults__", f.__kwdefaults__) - return wrapped + g = cast(types.FunctionType, functools.update_wrapper(g, f)) + g.__kwdefaults__ = f.__kwdefaults__ + return g From 070337af73ed7eb71aa6783514d7cbf099388e27 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 5 Feb 2026 08:57:43 +0100 Subject: [PATCH 15/27] fixed test --- src/transformers/utils/quantization_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index a225ef067833..b39a58890527 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1259,7 +1259,7 @@ def is_quantization_compressed(self): from compressed_tensors.quantization import QuantizationStatus qc = self.quantization_config - return qc is not None and bool(qc.config_groups) and qc.quantization_status == QuantizationStatus.COMPRESSED + return self.is_quantized and (qc is not None and qc.quantization_status == QuantizationStatus.COMPRESSED) @property def is_sparsification_compressed(self): From 243b60b4e5b03d5b697805fd3299ebd67376fd7f Mon Sep 17 00:00:00 2001 From: 
Tarek Ziade Date: Thu, 5 Feb 2026 09:04:33 +0100 Subject: [PATCH 16/27] tweaks --- src/transformers/utils/doc.py | 1 + src/transformers/utils/import_utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index 0521d346d92c..4e46e4230bac 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -21,6 +21,7 @@ import textwrap import types from collections import OrderedDict +from typing import cast def get_docstring_indentation_level(func): diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 74d61e27ac4f..9889321add5c 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -114,7 +114,7 @@ def is_torch_available() -> bool: is_available, torch_version = _is_package_available("torch", return_version=True) parsed_version = version.parse(torch_version) if is_available and parsed_version < version.parse("2.4.0"): - logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.4 is required but found {torch_version}") + logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.4 is required but found {torch_version}") # type: ignore return is_available and version.parse(torch_version) >= version.parse("2.4.0") except packaging.version.InvalidVersion: return False From 09bc801cc81aca60eaf40f30ddbf72d954f2dc2b Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 5 Feb 2026 12:08:29 +0100 Subject: [PATCH 17/27] added a logger protocol --- src/transformers/utils/_typing.py | 123 ++++++++++++++++++ src/transformers/utils/chat_template_utils.py | 2 +- src/transformers/utils/generic.py | 2 +- src/transformers/utils/logging.py | 6 +- 4 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 src/transformers/utils/_typing.py diff --git a/src/transformers/utils/_typing.py b/src/transformers/utils/_typing.py new file mode 100644 index 000000000000..5a9b2a09c949 --- 
/dev/null +++ b/src/transformers/utils/_typing.py @@ -0,0 +1,123 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import logging +from collections.abc import Mapping, MutableMapping +from typing import Any, Protocol, TypeAlias + + +# A few helpful type aliases +Level: TypeAlias = int +ExcInfo: TypeAlias = ( + None + | bool + | BaseException + | tuple[type[BaseException], BaseException, object] # traceback is `types.TracebackType`, but keep generic here +) + + +class TransformersLogger(Protocol): + # ---- Core Logger identity / configuration ---- + name: str + level: int + parent: logging.Logger | None + propagate: bool + disabled: bool + handlers: list[logging.Handler] + + # Exists on Logger; default is True. (Not heavily used, but is part of API.) + raiseExceptions: bool # type: ignore[assignment] + + # ---- Standard methods ---- + def setLevel(self, level: Level) -> None: ... + def isEnabledFor(self, level: Level) -> bool: ... + def getEffectiveLevel(self) -> int: ... + + def getChild(self, suffix: str) -> logging.Logger: ... + + def addHandler(self, hdlr: logging.Handler) -> None: ... + def removeHandler(self, hdlr: logging.Handler) -> None: ... + def hasHandlers(self) -> bool: ... + + # ---- Logging calls ---- + def debug(self, msg: object, *args: object, **kwargs: object) -> None: ... + def info(self, msg: object, *args: object, **kwargs: object) -> None: ... 
+ def warning(self, msg: object, *args: object, **kwargs: object) -> None: ... + def warn(self, msg: object, *args: object, **kwargs: object) -> None: ... + def error(self, msg: object, *args: object, **kwargs: object) -> None: ... + def exception(self, msg: object, *args: object, exc_info: ExcInfo = True, **kwargs: object) -> None: ... + def critical(self, msg: object, *args: object, **kwargs: object) -> None: ... + def fatal(self, msg: object, *args: object, **kwargs: object) -> None: ... + + # The lowest-level primitive + def log(self, level: Level, msg: object, *args: object, **kwargs: object) -> None: ... + + # ---- Record-level / formatting ---- + def makeRecord( + self, + name: str, + level: Level, + fn: str, + lno: int, + msg: object, + args: tuple[object, ...] | Mapping[str, object], + exc_info: ExcInfo, + func: str | None = None, + extra: Mapping[str, object] | None = None, + sinfo: str | None = None, + ) -> logging.LogRecord: ... + + def handle(self, record: logging.LogRecord) -> None: ... + def findCaller( + self, + stack_info: bool = False, + stacklevel: int = 1, + ) -> tuple[str, int, str, str | None]: ... + + def callHandlers(self, record: logging.LogRecord) -> None: ... + def getMessage(self) -> str: ... # NOTE: actually on LogRecord; included rarely; safe to omit if you want + + def _log( + self, + level: Level, + msg: object, + args: tuple[object, ...] | Mapping[str, object], + exc_info: ExcInfo = None, + extra: Mapping[str, object] | None = None, + stack_info: bool = False, + stacklevel: int = 1, + ) -> None: ... + + # ---- Filters ---- + def addFilter(self, filt: logging.Filter) -> None: ... + def removeFilter(self, filt: logging.Filter) -> None: ... + @property + def filters(self) -> list[logging.Filter]: ... + + def filter(self, record: logging.LogRecord) -> bool: ... + + # ---- Convenience helpers ---- + def setFormatter(self, fmt: logging.Formatter) -> None: ... 
# mostly on handlers; present on adapters sometimes + def debugStack(self, msg: object, *args: object, **kwargs: object) -> None: ... # not std; safe no-op if absent + + # ---- stdlib dictConfig-friendly / extra storage ---- + # Logger has `manager` and can have arbitrary attributes; Protocol can't express arbitrary attrs, + # but we can at least include `__dict__` to make "extra attributes" less painful. + __dict__: MutableMapping[str, Any] + + # ---- Your monkey-patched methods ---- + def warning_advice(self, msg: object, *args: object, **kwargs: object) -> None: ... + def warning_once(self, msg: object, *args: object, **kwargs: object) -> None: ... + def info_once(self, msg: object, *args: object, **kwargs: object) -> None: ... diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 3231851a12c0..8df4b953fe6c 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -495,7 +495,7 @@ def render_jinja_template( **kwargs, ) -> str: if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template): - logger.warning_once( # type: ignore[attr-defined] + logger.warning_once( "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword." ) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 98f0e575e7fd..f94e17e59e41 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -878,7 +878,7 @@ def wrapper(self, *args, **kwargs): # Arg-specific handling if arg_name == "use_cache": if getattr(self, "gradient_checkpointing", False) and self.training and arg_value: - logger.warning_once( # type: ignore[attr-defined] + logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
) arg_value = False diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index b2a220fd4d16..4b38b824ede7 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -33,6 +33,8 @@ import huggingface_hub.utils as hf_hub_utils from tqdm import auto as tqdm_lib +from ._typing import TransformersLogger + _lock = threading.Lock() _default_handler: logging.Handler | None = None @@ -144,7 +146,7 @@ def captureWarnings(capture): _captureWarnings(capture) -def get_logger(name: str | None = None) -> logging.Logger: +def get_logger(name: str | None = None) -> TransformersLogger: """ Return a logger with the specified name. @@ -313,8 +315,6 @@ def warning_advice(self, *args, **kwargs): self.warning(*args, **kwargs) -# TODO: ideally we should have a new logger class, e.g. TransformerLogger that adds these new methods -# instead of monkey patching logging.Logger.warning_advice = warning_advice # type: ignore[unresolved-attribute] From 5851b9a3e445fbecbfd82ead4e4cd2907534542a Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 5 Feb 2026 13:40:39 +0100 Subject: [PATCH 18/27] removed some asserts --- src/transformers/utils/notebook.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index 8b59ad78531f..ecbe8271fe13 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -15,7 +15,7 @@ import os import re import time -from typing import Optional +from typing import Optional, TypeVar import IPython.display as disp @@ -23,6 +23,15 @@ from ..trainer_utils import IntervalStrategy, has_length +_T = TypeVar("_T") + + +def _require(x: _T | None, msg: str) -> _T: + if x is None: + raise RuntimeError(msg) + return x + + def format_time(t): "Format `t` (in seconds) to (h):mm:ss" t = int(t) @@ -307,8 +316,8 @@ def on_train_begin(self, args, state, control, **kwargs): def 
on_step_end(self, args, state, control, **kwargs): epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}" - assert self.training_tracker is not None, "on_train_begin must be called before on_step_end" - self.training_tracker.update( + tt = _require(self.training_tracker, "on_train_begin must be called before on_step_end") + tt.update( state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}", force_update=self._force_next_update, @@ -335,14 +344,15 @@ def on_predict(self, args, state, control, **kwargs): def on_log(self, args, state, control, logs=None, **kwargs): # Only for when there is no evaluation if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: + tt = _require(self.training_tracker, "on_train_begin must be called before on_log") values = {"Training Loss": logs["loss"]} # First column is necessarily Step sine we're not in epoch eval strategy values["Step"] = state.global_step - assert self.training_tracker is not None, "on_train_begin must be called before on_log" - self.training_tracker.write_line(values) + tt.write_line(values) def on_evaluate(self, args, state, control, metrics=None, **kwargs): - assert self.training_tracker is not None, "on_train_begin must be called before on_evaluate" + tt = _require(self.training_tracker, "on_train_begin must be called before on_evaluate") + values = {"Training Loss": "No log", "Validation Loss": "No log"} for log in reversed(state.log_history): if "loss" in log: @@ -371,15 +381,15 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs): # Single dataset name = "Validation Loss" values[name] = v - self.training_tracker.write_line(values) - self.training_tracker.remove_child() + tt.write_line(values) + tt.remove_child() self.prediction_bar = None # Evaluation takes a long time so we should force the next update. 
self._force_next_update = True def on_train_end(self, args, state, control, **kwargs): - assert self.training_tracker is not None, "on_train_begin must be called before on_train_end" - self.training_tracker.update( + tt = _require(self.training_tracker, "on_train_begin must be called before on_train_end") + tt.update( state.global_step, comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", force_update=True, From 1a782165ba3a6f74d92cea1106b21a1ea5e6ec56 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 5 Feb 2026 14:07:52 +0100 Subject: [PATCH 19/27] just ignore type check on that one --- src/transformers/utils/chat_template_utils.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 8df4b953fe6c..001e90a2ce22 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -22,14 +22,7 @@ from datetime import datetime from functools import lru_cache from inspect import isfunction -from typing import ( - Any, - Literal, - Union, - get_args, - get_origin, - get_type_hints, -) +from typing import Any, Literal, Union, get_args, get_origin, get_type_hints, no_type_check from packaging import version @@ -409,11 +402,15 @@ def _render_with_assistant_indices( @lru_cache def _compile_jinja_template(chat_template): + return _cached_compile_jinja_template(chat_template) + + +@no_type_check +def _cached_compile_jinja_template(chat_template): if not is_jinja_available(): raise ImportError( "apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`." 
) - assert jinja2 is not None class AssistantTracker(Extension): # This extension is used to track the indices of assistant-generated tokens in the rendered chat @@ -426,17 +423,16 @@ def __init__(self, environment: ImmutableSandboxedEnvironment): self._rendered_blocks = None self._generation_indices = None - def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: # type: ignore[name-defined] + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: lineno = next(parser.stream).lineno body = parser.parse_statements(["name:endgeneration"], drop_needle=True) - return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) # type: ignore[attr-defined] + return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) @jinja2.pass_eval_context def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: rv = caller() if self.is_active(): # Only track generation indices if the tracker is active - assert self._rendered_blocks is not None and self._generation_indices is not None start_index = len("".join(self._rendered_blocks)) end_index = start_index + len(rv) self._generation_indices.append((start_index, end_index)) @@ -464,7 +460,6 @@ def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[ ) def raise_exception(message): - assert jinja2 is not None raise jinja2.exceptions.TemplateError(message) def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): From f4055ed67b13771972d8d172cfbe5b3f100497d6 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 5 Feb 2026 14:11:46 +0100 Subject: [PATCH 20/27] better description --- src/transformers/utils/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/_typing.py b/src/transformers/utils/_typing.py index 5a9b2a09c949..c98703340ee1 100644 --- a/src/transformers/utils/_typing.py
+++ b/src/transformers/utils/_typing.py @@ -117,7 +117,7 @@ def debugStack(self, msg: object, *args: object, **kwargs: object) -> None: ... # mostly on handlers; present on adapters sometimes # but we can at least include `__dict__` to make "extra attributes" less painful. __dict__: MutableMapping[str, Any] - # ---- Your monkey-patched methods ---- + # ---- Transformers logger specific methods ---- def warning_advice(self, msg: object, *args: object, **kwargs: object) -> None: ... def warning_once(self, msg: object, *args: object, **kwargs: object) -> None: ... def info_once(self, msg: object, *args: object, **kwargs: object) -> None: ... From 396a24cf6e6d900cb9ba1094dbfedb3b8328dc60 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Fri, 6 Feb 2026 10:05:39 +0100 Subject: [PATCH 21/27] not needed anymore --- src/transformers/utils/import_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 9889321add5c..74d61e27ac4f 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -114,7 +114,7 @@ def is_torch_available() -> bool: is_available, torch_version = _is_package_available("torch", return_version=True) parsed_version = version.parse(torch_version) if is_available and parsed_version < version.parse("2.4.0"): - logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.4 is required but found {torch_version}") # type: ignore + logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.4 is required but found {torch_version}") return is_available and version.parse(torch_version) >= version.parse("2.4.0") except packaging.version.InvalidVersion: return False From e32710aaf2e688b78464221c2fad66de5b534650 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Fri, 6 Feb 2026 10:09:41 +0100 Subject: [PATCH 22/27] yeah callables don't always have names --- src/transformers/utils/chat_template_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git
a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 001e90a2ce22..53d438272bca 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -353,10 +353,10 @@ def get_json_schema(func: Callable) -> dict: } """ doc = inspect.getdoc(func) + func_name = getattr(func, "__name__", "operation") + if not doc: - raise DocstringParsingException( - f"Cannot generate JSON schema for {func.__name__} because it has no docstring!" # type: ignore[attr-defined] - ) + raise DocstringParsingException(f"Cannot generate JSON schema for {func_name} because it has no docstring!") doc = doc.strip() main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) @@ -367,7 +367,7 @@ def get_json_schema(func: Callable) -> dict: for arg, schema in json_schema["properties"].items(): if arg not in param_descriptions: raise DocstringParsingException( - f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" # type: ignore[attr-defined] + f"Cannot generate JSON schema for {func_name} because the docstring has no description for the argument '{arg}'" ) desc = param_descriptions[arg] enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE) @@ -376,7 +376,7 @@ def get_json_schema(func: Callable) -> dict: desc = enum_choices.string[: enum_choices.start()].strip() schema["description"] = desc - output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} # type: ignore[attr-defined] + output = {"name": func_name, "description": main_doc, "parameters": json_schema} if return_dict is not None: output["return"] = return_dict return {"type": "function", "function": output} From 91e998be46522549db9d8b1e7f33a30623714728 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Fri, 6 Feb 2026 10:20:47 +0100 Subject: [PATCH 23/27] one more func name fix ---
src/transformers/utils/chat_template_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 53d438272bca..a90073b39a54 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -178,6 +178,7 @@ def _parse_type_hint(hint: str) -> dict: def _convert_type_hints_to_json_schema(func: Callable) -> dict: type_hints = get_type_hints(func) signature = inspect.signature(func) + func_name = getattr(func, "__name__", "operation") # For methods, we need to ignore the first "self" or "cls" parameter. Here we assume that if the first parameter # is named "self" or "cls" and has no type hint, it is an implicit receiver argument. first_param_name = next(iter(signature.parameters), None) @@ -188,13 +189,12 @@ def _convert_type_hints_to_json_schema(func: Callable) -> dict: implicit_arg_name = first_param_name else: implicit_arg_name = None - required = [] for param_name, param in signature.parameters.items(): if param_name == implicit_arg_name: continue if param.annotation == inspect.Parameter.empty: - raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") # type: ignore[attr-defined] + raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func_name}") if param.default == inspect.Parameter.empty: required.append(param_name) From 57c5ec7de639e599445fa10acce11fbd1d67a338 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Fri, 20 Feb 2026 08:33:55 +0100 Subject: [PATCH 24/27] more tweaks --- src/transformers/integrations/integration_utils.py | 1 + src/transformers/utils/generic.py | 2 +- src/transformers/utils/output_capturing.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 
7fe42297a5cc..aea37ec7a9c4 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -37,6 +37,7 @@ from transformers.utils.import_utils import is_pynvml_available + if os.getenv("WANDB_MODE") == "offline": print("[INFO] Running in WANDB offline mode") diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index f94e17e59e41..51f30090f752 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -462,7 +462,7 @@ def _model_output_flatten(output: ModelOutput) -> tuple[list[Any], _torch_pytree def _model_output_unflatten( values: Iterable[Any], - context: "_torch_pytree.Context", + context: _torch_pytree.Context, output_type: type[ModelOutput] | None = None, ) -> ModelOutput: return output_type(**dict(zip(context, values))) diff --git a/src/transformers/utils/output_capturing.py b/src/transformers/utils/output_capturing.py index 5912df07cefd..aa5312212e29 100644 --- a/src/transformers/utils/output_capturing.py +++ b/src/transformers/utils/output_capturing.py @@ -176,7 +176,7 @@ def install_all_output_capturing_hooks(model: PreTrainedModel, prefix: str | Non prefix = prefix if prefix is not None else "" recursively_install_hooks(model, prefix, capture_tasks) # Mark the model as already hooked - model._output_capturing_hooks_installed = True + setattr(model, "_output_capturing_hooks_installed", True) # We need this to make sure we don't have race conditions when installing hooks, resulting in them being installed From 00a4c0f0594cd7ad343e5da578ee9f54d44122b1 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Tue, 3 Feb 2026 15:24:28 +0100 Subject: [PATCH 25/27] fix: VersionComparison.from_string() returns an enum not a str --- src/transformers/dynamic_module_utils.py | 2 +- src/transformers/utils/import_utils.py | 2 +- tests/utils/test_import_structure.py | 7 +++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git 
a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index e94617d2f51e..9c9e7b929f6f 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -792,7 +792,7 @@ def check_python_requirements(path_or_repo_id, requirements_file="requirements.t continue if delimiter is not None and version_number is not None: - is_satisfied = VersionComparison.from_string(delimiter)( + is_satisfied = VersionComparison.from_string(delimiter).value( version.parse(local_package_version), version.parse(version_number) ) else: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 74d61e27ac4f..01fc5d36c81d 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -2331,7 +2331,7 @@ def get_installed_version(self) -> str: return current_version def is_satisfied(self) -> bool: - return VersionComparison.from_string(self.version_comparison)( + return VersionComparison.from_string(self.version_comparison).value( version.parse(self.get_installed_version()), version.parse(self.version) ) diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py index 0c3d8fff917b..0fb14501c728 100644 --- a/tests/utils/test_import_structure.py +++ b/tests/utils/test_import_structure.py @@ -1,6 +1,5 @@ import os import unittest -from collections.abc import Callable from pathlib import Path import pytest @@ -197,11 +196,11 @@ def test_import_spread(self): @pytest.mark.parametrize( "backend,package_name,version_comparison,version", [ - pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"), - pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"), + pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL, "2.5"), + pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL, 
"0.19.1"), ], ) -def test_backend_specification(backend: Backend, package_name: str, version_comparison: Callable, version: str): +def test_backend_specification(backend: Backend, package_name: str, version_comparison: VersionComparison, version: str): assert backend.package_name == package_name assert VersionComparison.from_string(backend.version_comparison) == version_comparison assert backend.version == version From 5c06349d573d72e26b9e720126b3112e2d052c0e Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Tue, 3 Feb 2026 15:25:04 +0100 Subject: [PATCH 26/27] format --- tests/utils/test_import_structure.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py index 0fb14501c728..fb48d35d5248 100644 --- a/tests/utils/test_import_structure.py +++ b/tests/utils/test_import_structure.py @@ -200,7 +200,9 @@ def test_import_spread(self): pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL, "0.19.1"), ], ) -def test_backend_specification(backend: Backend, package_name: str, version_comparison: VersionComparison, version: str): +def test_backend_specification( + backend: Backend, package_name: str, version_comparison: VersionComparison, version: str +): assert backend.package_name == package_name assert VersionComparison.from_string(backend.version_comparison) == version_comparison assert backend.version == version From 9869da84952510b787f58855ad88fec19fd9fa6e Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Fri, 20 Feb 2026 16:29:46 +0100 Subject: [PATCH 27/27] Drop accidental Trackio pynvml/GPU tracking reintroduction --- .../integrations/integration_utils.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index aea37ec7a9c4..87865fb6b94a 100755 --- a/src/transformers/integrations/integration_utils.py +++ 
b/src/transformers/integrations/integration_utils.py @@ -35,8 +35,6 @@ import numpy as np import packaging.version -from transformers.utils.import_utils import is_pynvml_available - if os.getenv("WANDB_MODE") == "offline": print("[INFO] Running in WANDB offline mode") @@ -58,7 +56,6 @@ if is_torch_available(): import torch - import torch.distributed as dist # comet_ml requires to be imported before any ML frameworks _MIN_COMET_VERSION = "3.43.2" @@ -1001,24 +998,6 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs): "total_flos", ] - if is_torch_available() and torch.cuda.is_available(): - device_idx = torch.cuda.current_device() - total_memory = torch.cuda.get_device_properties(device_idx).total_memory - memory_allocated = torch.cuda.memory_allocated(device_idx) - - gpu_memory_logs = { - f"gpu/{device_idx}/allocated_memory": memory_allocated / (1024**3), # GB - f"gpu/{device_idx}/memory_usage": memory_allocated / total_memory, # ratio - } - if is_pynvml_available(): - power = torch.cuda.power_draw(device_idx) - gpu_memory_logs[f"gpu/{device_idx}/power"] = power / 1000 # Watts - if dist.is_available() and dist.is_initialized(): - gathered_logs = [None] * dist.get_world_size() - dist.all_gather_object(gathered_logs, gpu_memory_logs) - gpu_memory_logs = {k: v for d in gathered_logs for k, v in d.items()} - else: - gpu_memory_logs = {} if not self._initialized: self.setup(args, state, model) if state.is_world_process_zero: