diff --git a/all_requirements.txt b/all_requirements.txt new file mode 100644 index 000000000000..eacb47727a64 --- /dev/null +++ b/all_requirements.txt @@ -0,0 +1,107 @@ +gpustat==1.1.1 +psutil==6.0.0 +psycopg2==2.9.9 +pandas>=1.5.0 +numpy>=1.21.0 +psutil>=5.8.0 +nvidia-ml-py>=12.0.0 +torch>=2.0.0 +datasets>=2.10.0 +huggingface_hub>=0.16.0 +amdsmi>=7.0.2 +git+https://github.com/huggingface/transformers.git@main # install from main, or replace @main with vX.X.X to install a specific transformers version +datasets==1.8.0 +accelerate >= 0.12.0 +datasets >= 1.8.0 +torch >= 1.3.0 +evaluate +accelerate >= 0.21.0 +sentencepiece != 0.1.92 +protobuf +torch >= 1.3 +datasets[audio]>=1.14.0 +evaluate +librosa +torchaudio +torch>=1.6 +accelerate >= 0.12.0 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +protobuf +sacrebleu >= 1.4.12 +py7zr +torch >= 1.3 +evaluate +datasets >= 2.0.0 +torch >= 1.3 +accelerate +evaluate +Pillow +albumentations >= 1.4.16 +accelerate >= 0.12.0 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +protobuf +rouge-score +nltk +py7zr +torch >= 1.3 +evaluate +torch>=1.5.0 +torchvision>=0.6.0 +datasets>=1.8.0 +accelerate >= 0.12.0 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +scipy +scikit-learn +protobuf +torch >= 1.3 +evaluate +accelerate>=0.12.0 +torch>=1.5.0 +torchvision>=0.6.0 +datasets>=2.14.0 +evaluate +scikit-learn +accelerate >= 0.12.0 +torch >= 1.3 +datasets >= 2.14.0 +sentencepiece != 0.1.92 +protobuf +evaluate +scikit-learn +accelerate >= 0.12.0 +seqeval +datasets >= 1.8.0 +torch >= 1.3 +evaluate +albumentations >= 1.4.16 +timm +datasets>=4.0 +torchmetrics +pycocotools +datasets[audio] >= 1.18.0 +torch >= 1.5 +torchaudio +librosa +jiwer +evaluate +datasets[audio] >= 1.12.0 +torch >= 1.5 +torchaudio +accelerate >= 0.12.0 +librosa +torch>=1.5.0 +torchvision>=0.6.0 +datasets>=1.8.0 +albumentations >= 1.4.16 +timm +datasets +torchmetrics +pycocotools +accelerate >= 0.12.0 +sentencepiece != 0.1.92 +protobuf +torch >= 1.3 +evaluate diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 5d3d3145ef00..5a3bd7f40ace 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -43,7 +43,6 @@ ) from .utils.generic import is_timm_config_dict - if TYPE_CHECKING: import torch @@ -52,10 +51,16 @@ # type hinting: specifying the type of config class that inherits from PreTrainedConfig -SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig") +SpecificPreTrainedConfigType = TypeVar( + "SpecificPreTrainedConfigType", bound="PreTrainedConfig" +) _FLOAT_TAG_KEY = "__float__" -_FLOAT_TAG_VALUES = {"Infinity": float("inf"), "-Infinity": float("-inf"), "NaN": float("nan")} +_FLOAT_TAG_VALUES = { + "Infinity": float("inf"), + "-Infinity": float("-inf"), + "NaN": float("nan"), +} ALLOWED_LAYER_TYPES = ( @@ -71,26 +76,29 @@ ) -# copied from huggingface_hub.dataclasses.strict when `accept_kwargs=True` def wrap_init_to_accept_kwargs(cls: dataclass): original_init = cls.__init__ @wraps(original_init) def __init__(self, *args, **kwargs: Any) -> None: - # Extract only the fields that are part of the dataclass dataclass_fields = {f.name for f in fields(cls)} standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields} - # We need to call bare `__init__` without `__post_init__` but the `original_init` of - # any dataclas contains a call to post-init at the end (without kwargs) if len(args) > 0: raise ValueError( f"{cls.__name__} accepts only keyword arguments, but found `{len(args)}` positional args."
) + # Set standard fields for f in fields(cls): # type: ignore if f.name in standard_kwargs: setattr(self, f.name, standard_kwargs[f.name]) + # Validate that `num_labels`, if provided, is an integer + if f.name == "num_labels" and getattr(self, f.name) is not None: + if not isinstance(getattr(self, f.name), int): + raise TypeError( + f"num_labels must be int, got {type(getattr(self, f.name))}" + ) elif f.default is not MISSING: setattr(self, f.name, f.default) elif f.default_factory is not MISSING: @@ -98,13 +106,10 @@ def __init__(self, *args, **kwargs: Any) -> None: else: raise TypeError(f"Missing required field - '{f.name}'") - # Pass any additional kwargs to `__post_init__` and let the object - # decide whether to set the attr or use for different purposes (e.g. BC checks) - additional_kwargs = {} - for name, value in kwargs.items(): - if name not in dataclass_fields: - additional_kwargs[name] = value - + # Handle extra kwargs + additional_kwargs = { + k: v for k, v in kwargs.items() if k not in dataclass_fields + } self.__post_init__(**additional_kwargs) cls.__init__ = __init__ @@ -231,7 +236,12 @@ class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin): # Fine-tuning task arguments id2label: dict[int, str] | dict[str, str] | None = None label2id: dict[str, int] | dict[str, str] | None = None - problem_type: Literal["regression", "single_label_classification", "multi_label_classification"] | None = None + problem_type: ( + Literal[ + "regression", "single_label_classification", "multi_label_classification" + ] + | None + ) = None def __post_init__(self, **kwargs): # BC for the `torch_dtype` argument instead of the simpler `dtype` @@ -239,7 +249,11 @@ def __post_init__(self, **kwargs): if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None: # If both are provided, keep `dtype` self.dtype = self.dtype if self.dtype is not None else torch_dtype - if self.dtype is not None and isinstance(self.dtype, str) and is_torch_available(): + if ( + self.dtype is not None + and isinstance(self.dtype, str) + and is_torch_available() + ): # we will start using self.dtype in v5, but to be consistent with # from_pretrained's dtype arg convert it to an actual torch.dtype object import torch @@ -252,7 +266,9 @@ def __post_init__(self, **kwargs): if self.id2label is None: self.num_labels = kwargs.get("num_labels", 2) else: - if kwargs.get("num_labels") is not None and len(self.id2label) != kwargs.get("num_labels"): + if kwargs.get("num_labels") is not None and len( + self.id2label + ) != kwargs.get("num_labels"): logger.warning( f"You passed `num_labels={kwargs.get('num_labels')}` which is incompatible to " f"the `id2label` map of length `{len(self.id2label)}`."
@@ -281,12 +297,17 @@ def __post_init__(self, **kwargs): # Attention/Experts implementation to use, if relevant (it sets it recursively on sub-configs) self._output_attentions: bool | None = kwargs.pop("output_attentions", False) self._attn_implementation: str | None = kwargs.pop("attn_implementation", None) - self._experts_implementation: str | None = kwargs.pop("experts_implementation", None) + self._experts_implementation: str | None = kwargs.pop( + "experts_implementation", None + ) # Additional attributes without default values for key, value in kwargs.items(): # Check this to avoid deserializing problematic fields from hub configs - they should use the public field - if key not in ("_attn_implementation_internal", "_experts_implementation_internal"): + if key not in ( + "_attn_implementation_internal", + "_experts_implementation_internal", + ): try: setattr(self, key, value) except AttributeError as err: @@ -311,7 +332,9 @@ def name_or_path(self) -> str | None: @name_or_path.setter def name_or_path(self, value): - self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding) + self._name_or_path = str( + value + ) # Make sure that name_or_path is a string (for JSON encoding) @property def num_labels(self) -> int: @@ -356,16 +379,22 @@ def _attn_implementation(self, value: str | dict | None): """We set it recursively on the sub-configs as well""" # Set if for current config current_attn = getattr(self, "_attn_implementation", None) - attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn) + attn_implementation = ( + value if not isinstance(value, dict) else value.get("", current_attn) + ) self._attn_implementation_internal = attn_implementation # Set it recursively on the subconfigs for subconfig_key in self.sub_configs: subconfig = getattr(self, subconfig_key, None) if subconfig is not None: - current_subconfig_attn = getattr(subconfig, "_attn_implementation", None) + current_subconfig_attn = getattr( + subconfig, "_attn_implementation", None + ) sub_implementation = ( - value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn) + value + if not isinstance(value, dict) + else value.get(subconfig_key, current_subconfig_attn) ) subconfig._attn_implementation = sub_implementation @@ -378,16 +407,22 @@ def _experts_implementation(self, value: str | dict | None): """We set it recursively on the sub-configs as well""" # Set if for current config current_moe = getattr(self, "_experts_implementation", None) - experts_implementation = value if not isinstance(value, dict) else value.get("", current_moe) + experts_implementation = ( + value if not isinstance(value, dict) else value.get("", current_moe) + ) self._experts_implementation_internal = experts_implementation # Set it recursively on the subconfigs for subconfig_key in self.sub_configs: subconfig = getattr(self, subconfig_key, None) if subconfig is not None: - current_subconfig_moe = getattr(subconfig, "_experts_implementation", None) + current_subconfig_moe = getattr( + subconfig, "_experts_implementation", None + ) sub_implementation = ( - value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_moe) + value + if not isinstance(value, dict) + else value.get(subconfig_key, current_subconfig_moe) ) subconfig._experts_implementation = sub_implementation @@ -398,7 +433,9 @@ def torch_dtype(self): @property def use_return_dict(self): - logger.warning_once("`use_return_dict` is deprecated! 
Use `return_dict` instead!") + logger.warning_once( + "`use_return_dict` is deprecated! Use `return_dict` instead!" + ) return self.return_dict @torch_dtype.setter @@ -443,7 +480,11 @@ def validate_token_ids(self): if vocab_size is not None: # Check for all special tokens, e..g. pad_token_id, image_token_id, audio_token_id for value in text_config: - if value.endswith("_token_id") and isinstance(value, int) and not 0 <= value < vocab_size: + if ( + value.endswith("_token_id") + and isinstance(value, int) + and not 0 <= value < vocab_size + ): # Can't be an exception until we can load configs that fail validation: several configs on the Hub # store invalid special tokens, e.g. `pad_token_id=-1` logger.warning_once( @@ -453,11 +494,20 @@ def validate_token_ids(self): def validate_layer_type(self): """Check that `layer_types` is correctly defined.""" - if not (getattr(self, "layer_types", None) is not None and hasattr(self, "num_hidden_layers")): + if not ( + getattr(self, "layer_types", None) is not None + and hasattr(self, "num_hidden_layers") + ): return - elif not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in self.layer_types): - raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES} but got {self.layer_types}") - elif self.num_hidden_layers is not None and self.num_hidden_layers != len(self.layer_types): + elif not all( + layer_type in ALLOWED_LAYER_TYPES for layer_type in self.layer_types + ): + raise ValueError( + f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES} but got {self.layer_types}" + ) + elif self.num_hidden_layers is not None and self.num_hidden_layers != len( + self.layer_types + ): raise ValueError( f"`num_hidden_layers` ({self.num_hidden_layers}) must be equal to the number of layer types " f"({len(self.layer_types)})" @@ -471,7 +521,9 @@ def rope_scaling(self): def rope_scaling(self, value): self.rope_parameters = value - def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): + def save_pretrained( + self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs + ): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the [`~PreTrainedConfig.from_pretrained`] class method. @@ -487,7 +539,9 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) generation_parameters = self._get_generation_parameters() if len(generation_parameters) > 0: @@ -621,11 +675,17 @@ def from_pretrained( kwargs["local_files_only"] = local_files_only kwargs["revision"] = revision - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) if cls.base_config_key and cls.base_config_key in config_dict: config_dict = config_dict[cls.base_config_key] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): # sometimes the config has no `base_config_key` if the config is used in several composite models # e.g. 
LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning for v in config_dict.values(): @@ -662,7 +722,9 @@ def get_config_dict( """ original_kwargs = copy.deepcopy(kwargs) # Get config dict associated with the base config file - config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + config_dict, kwargs = cls._get_config_dict( + pretrained_model_name_or_path, **kwargs + ) if config_dict is None: return {}, kwargs if "_commit_hash" in config_dict: @@ -670,9 +732,13 @@ def get_config_dict( # That config file may point us toward another config file to use. if "configuration_files" in config_dict: - configuration_file = get_configuration_file(config_dict["configuration_files"]) + configuration_file = get_configuration_file( + config_dict["configuration_files"] + ) config_dict, kwargs = cls._get_config_dict( - pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs + pretrained_model_name_or_path, + _configuration_file=configuration_file, + **original_kwargs, ) return config_dict, kwargs @@ -713,7 +779,11 @@ def _get_config_dict( resolved_config_file = pretrained_model_name_or_path is_local = True else: - configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file + configuration_file = ( + kwargs.pop("_configuration_file", CONFIG_NAME) + if gguf_file is None + else gguf_file + ) try: # Load from local folder or from cache or download from model Hub and cache @@ -748,19 +818,25 @@ def _get_config_dict( try: if gguf_file: - config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"] + config_dict = load_gguf_checkpoint( + resolved_config_file, return_tensors=False + )["config"] else: # Load config dict config_dict = cls._dict_from_json_file(resolved_config_file) config_dict["_commit_hash"] = commit_hash except (json.JSONDecodeError, UnicodeDecodeError): - raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.") + raise OSError( + f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." + ) if is_local: logger.info(f"loading configuration file {resolved_config_file}") else: - logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}") + logger.info( + f"loading configuration file {configuration_file} from cache at {resolved_config_file}" + ) # timm models are not saved with the model_type in the config file if "model_type" not in config_dict and is_timm_config_dict(config_dict): @@ -826,7 +902,9 @@ def from_dict( current_attr = getattr(config, key) # To authorize passing a custom subconfig as kwarg in models that have nested configs. # We need to update only custom kwarg values instead and keep other attr in subconfig. - if isinstance(current_attr, PreTrainedConfig) and isinstance(value, dict): + if isinstance(current_attr, PreTrainedConfig) and isinstance( + value, dict + ): current_attr_updated = current_attr.to_dict() current_attr_updated.update(value) value = current_attr.__class__(**current_attr_updated) @@ -902,7 +980,9 @@ def _decode_special_floats(cls, obj: Any) -> Any: This method deserializes objects like `{'__float__': Infinity}` to their float values like `Infinity`. 
""" if isinstance(obj, dict): - if set(obj.keys()) == {_FLOAT_TAG_KEY} and isinstance(obj[_FLOAT_TAG_KEY], str): + if set(obj.keys()) == {_FLOAT_TAG_KEY} and isinstance( + obj[_FLOAT_TAG_KEY], str + ): tag = obj[_FLOAT_TAG_KEY] if tag in _FLOAT_TAG_VALUES: return _FLOAT_TAG_VALUES[tag] @@ -939,7 +1019,9 @@ def to_diff_dict(self) -> dict[str, Any]: default_config_dict = PreTrainedConfig().to_dict() # get class specific config dict - class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {} + class_config_dict = ( + self.__class__().to_dict() if not self.has_no_defaults_at_init else {} + ) serializable_config_dict = {} @@ -952,7 +1034,9 @@ def to_diff_dict(self) -> dict[str, Any]: and isinstance(class_config_dict[key], dict) ): # For nested configs we need to clean the diff recursively - diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None)) + diff = recursive_diff_dict( + value, default_config_dict, config_obj=getattr(self, key, None) + ) if "model_type" in value: # Needs to be set even if it's not in the diff diff["model_type"] = value["model_type"] @@ -963,7 +1047,10 @@ def to_diff_dict(self) -> dict[str, Any]: or key == "transformers_version" or key == "vocab_file" or value != default_config_dict[key] - or (key in default_config_dict and value != class_config_dict.get(key, value)) + or ( + key in default_config_dict + and value != class_config_dict.get(key, value) + ) ): serializable_config_dict[key] = value @@ -976,7 +1063,8 @@ def to_diff_dict(self) -> dict[str, Any]: if hasattr(self, "quantization_config"): serializable_config_dict["quantization_config"] = ( self.quantization_config.to_dict() - if not isinstance(self.quantization_config, dict) and self.quantization_config is not None + if not isinstance(self.quantization_config, dict) + and self.quantization_config is not None else self.quantization_config ) self.dict_dtype_to_str(serializable_config_dict) @@ -1023,7 +1111,8 @@ def to_list(value): if hasattr(self, "quantization_config"): output["quantization_config"] = ( self.quantization_config.to_dict() - if not isinstance(self.quantization_config, dict) and self.quantization_config is not None + if not isinstance(self.quantization_config, dict) + and self.quantization_config is not None else self.quantization_config ) self.dict_dtype_to_str(output) @@ -1185,11 +1274,17 @@ def _get_generation_parameters(self) -> dict[str, Any]: if there are any. """ generation_params = {} - default_config = self.__class__().to_dict() if not self.has_no_defaults_at_init else {} + default_config = ( + self.__class__().to_dict() if not self.has_no_defaults_at_init else {} + ) for key in GenerationConfig._get_default_generation_params().keys(): if key == "use_cache": continue # common key for most models - if hasattr(self, key) and getattr(self, key) is not None and key not in default_config: + if ( + hasattr(self, key) + and getattr(self, key) is not None + and key not in default_config + ): generation_params[key] = getattr(self, key) return generation_params @@ -1212,12 +1307,16 @@ def get_text_config(self, decoder=None, encoder=None) -> "PreTrainedConfig": encoder (`Optional[bool]`, *optional*): If set to `True`, then only search for encoder config names. 
""" - return_both = decoder == encoder # both unset or both set -> search all possible names + return_both = ( + decoder == encoder + ) # both unset or both set -> search all possible names decoder_possible_text_config_names = ("decoder", "generator", "text_config") encoder_possible_text_config_names = ("text_encoder",) if return_both: - possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names + possible_text_config_names = ( + encoder_possible_text_config_names + decoder_possible_text_config_names + ) elif decoder: possible_text_config_names = decoder_possible_text_config_names else: @@ -1242,7 +1341,11 @@ def get_text_config(self, decoder=None, encoder=None) -> "PreTrainedConfig": config_to_return = self # handle legacy models with flat config structure, when we only want one of the configs - if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder: + if ( + not return_both + and len(valid_text_config_names) == 0 + and config_to_return.is_encoder_decoder + ): config_to_return = copy.deepcopy(config_to_return) prefix_to_keep = "decoder" if decoder else "encoder" for key in config_to_return.to_dict(): @@ -1284,7 +1387,11 @@ def get_configuration_file(configuration_files: list[str]) -> str: """ configuration_files_map = {} for file_name in configuration_files: - if file_name.startswith("config.") and file_name.endswith(".json") and file_name != "config.json": + if ( + file_name.startswith("config.") + and file_name.endswith(".json") + and file_name != "config.json" + ): v = file_name.removeprefix("config.").removesuffix(".json") configuration_files_map[v] = file_name available_versions = sorted(configuration_files_map.keys()) @@ -1313,7 +1420,11 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None): default = config_obj.__class__().to_dict() if config_obj is not None else {} for key, value in dict_a.items(): obj_value = getattr(config_obj, str(key), None) - if isinstance(obj_value, PreTrainedConfig) and key in dict_b and isinstance(dict_b[key], dict): + if ( + isinstance(obj_value, PreTrainedConfig) + and key in dict_b + and isinstance(dict_b[key], dict) + ): diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value) diff[key] = diff_value elif key not in dict_b or (value != default[key]): @@ -1332,7 +1443,9 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None): PretrainedConfig = PreTrainedConfig -def layer_type_validation(layer_types: list[str], num_hidden_layers: int | None = None, attention: bool = True): +def layer_type_validation( + layer_types: list[str], num_hidden_layers: int | None = None, attention: bool = True +): logger.warning( "`layer_type_validation` is deprecated and will be removed in v5.20. 
" "Use `PreTrainedConfig.validate_layer_type` instead" @@ -1341,7 +1454,22 @@ def layer_type_validation(layer_types: list[str], num_hidden_layers: int | None if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types): raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}") if num_hidden_layers is not None and num_hidden_layers != len(layer_types): + print("DEBUG:", num_hidden_layers, len(layer_types)) raise ValueError( f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types " f"({len(layer_types)})" ) + + +if __name__ == "__main__": + from transformers.configuration_utils import PretrainedConfig + + # Correct type test + cfg = PretrainedConfig(num_labels=5) + print("Correct type test passed:", cfg.num_labels) + + # Incorrect type test (TypeError raise hoga) + try: + cfg_bad = PretrainedConfig(num_labels="five") + except TypeError as e: + print("Incorrect type test passed:", e) diff --git a/src/transformers/test.py b/src/transformers/test.py new file mode 100644 index 000000000000..5ea03611c4ba --- /dev/null +++ b/src/transformers/test.py @@ -0,0 +1,47 @@ +from typing import Any, TypeVar + + + +class PretrainedConfig: + model_type: str + num_labels: int | None = None + + def __init__(self, **kwargs: Any): + + for k, v in kwargs.items(): + setattr(self, k, v) + + + if hasattr(self, "num_labels") and self.num_labels is not None: + if not isinstance(self.num_labels, int): + raise TypeError(f"num_labels must be int, got {type(self.num_labels)}") + + + @classmethod + def from_pretrained( + cls: TypeVar("PretrainedConfig"), *args, **kwargs + ) -> "PretrainedConfig": + config = cls(**kwargs) + # Type check example + if config.num_labels is not None and not isinstance(config.num_labels, int): + raise TypeError(f"num_labels must be int, got {type(config.num_labels)}") + return config + + # Ensure type hints for to_dict + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + for k, v in self.__dict__.items(): + result[k] = v + return result + + +if __name__ == "__main__": + # Correct type + cfg = PretrainedConfig(num_labels=5) + print("Correct type test passed:", cfg.num_labels) + + + try: + cfg_bad = PretrainedConfig(num_labels="five") + except TypeError as e: + print("Incorrect type test passed:", e) diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 6375851dc770..a8c0161fa4e3 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -31,7 +31,6 @@ ) from .generic import ModelOutput - PATH_TO_TRANSFORMERS = Path("src").resolve() / "transformers" @@ -49,7 +48,10 @@ "image_processor_class": ("image_processing_auto", "IMAGE_PROCESSOR_MAPPING_NAMES"), "tokenizer_class": ("tokenization_auto", "TOKENIZER_MAPPING_NAMES"), "video_processor_class": ("video_processing_auto", "VIDEO_PROCESSOR_MAPPING_NAMES"), - "feature_extractor_class": ("feature_extraction_auto", "FEATURE_EXTRACTOR_MAPPING_NAMES"), + "feature_extractor_class": ( + "feature_extraction_auto", + "FEATURE_EXTRACTOR_MAPPING_NAMES", + ), "processor_class": ("processing_auto", "PROCESSOR_MAPPING_NAMES"), "config_class": ("configuration_auto", "CONFIG_MAPPING_NAMES"), "model_class": ("modeling_auto", "MODEL_MAPPING_NAMES"), @@ -628,11 +630,9 @@ class ConfigArgs: """, } - head_dim = { - "description": """ + head_dim = {"description": """ The attention head dimension. 
If None, it will default to hidden_size // num_attention_heads - """ - } + """} num_hidden_layers = { "description": """ @@ -2581,7 +2581,10 @@ def equalize_indent(docstring: str, indent_level: int) -> str: prefix = " " * indent_level # Uses splitlines() (no keepends) to match previous behaviour that dropped # any trailing newline via the old splitlines() + "\n".join() + textwrap.indent path. - return "\n".join(prefix + line.lstrip() if line.strip() else "" for line in docstring.splitlines()) + return "\n".join( + prefix + line.lstrip() if line.strip() else "" + for line in docstring.splitlines() + ) def set_min_indent(docstring: str, indent_level: int) -> str: @@ -2596,7 +2599,9 @@ def set_min_indent(docstring: str, indent_level: int) -> str: default=0, ) prefix = " " * indent_level - return "\n".join(prefix + line[min_indent:] if line.strip() else "" for line in lines) + return "\n".join( + prefix + line[min_indent:] if line.strip() else "" for line in lines + ) def parse_shape(docstring): @@ -2641,14 +2646,20 @@ def parse_docstring(docstring, max_indent_level=0, return_intro=False): docstring_intro = docstring[: args_match.start()] if docstring_intro.split("\n")[-1].strip() == '"""': docstring_intro = "\n".join(docstring_intro.split("\n")[:-1]) - if docstring_intro.split("\n")[0].strip() == 'r"""' or docstring_intro.split("\n")[0].strip() == '"""': + if ( + docstring_intro.split("\n")[0].strip() == 'r"""' + or docstring_intro.split("\n")[0].strip() == '"""' + ): docstring_intro = "\n".join(docstring_intro.split("\n")[1:]) if docstring_intro.strip() == "": docstring_intro = None args_section = args_match.group(1).lstrip("\n") if args_match else docstring if args_section.split("\n")[-1].strip() == '"""': args_section = "\n".join(args_section.split("\n")[:-1]) - if args_section.split("\n")[0].strip() == 'r"""' or args_section.split("\n")[0].strip() == '"""': + if ( + args_section.split("\n")[0].strip() == 'r"""' + or args_section.split("\n")[0].strip() == '"""' + ): args_section = "\n".join(args_section.split("\n")[1:]) args_section = set_min_indent(args_section, 0) params = {} @@ -2730,12 +2741,18 @@ def get_model_name(obj): if file_name.startswith(start) and file_name.endswith(end): model_name_lowercase_from_file = file_name[len(start) : -len(end)] break - if model_name_lowercase_from_file and model_name_lowercase_from_folder != model_name_lowercase_from_file: - from transformers.models.auto.configuration_auto import SPECIAL_MODEL_TYPE_TO_MODULE_NAME + if ( + model_name_lowercase_from_file + and model_name_lowercase_from_folder != model_name_lowercase_from_file + ): + from transformers.models.auto.configuration_auto import ( + SPECIAL_MODEL_TYPE_TO_MODULE_NAME, + ) if ( model_name_lowercase_from_file in SPECIAL_MODEL_TYPE_TO_MODULE_NAME - or model_name_lowercase_from_file.replace("_", "-") in SPECIAL_MODEL_TYPE_TO_MODULE_NAME + or model_name_lowercase_from_file.replace("_", "-") + in SPECIAL_MODEL_TYPE_TO_MODULE_NAME ): return model_name_lowercase_from_file return model_name_lowercase_from_folder @@ -2781,14 +2798,18 @@ def generate_processor_intro(cls) -> str: elif len(components) == 2: components_text = f"a {components[0]} and a {components[1]}" classes_text = f"{component_classes[0]} and {component_classes[1]}" - classes_text_short = ( - f"{component_classes[0].replace('[`', '[`~')} and {component_classes[1].replace('[`', '[`~')}" - ) + classes_text_short = f"{component_classes[0].replace('[`', '[`~')} and {component_classes[1].replace('[`', '[`~')}" else: - components_text = ", 
".join(f"a {c}" for c in components[:-1]) + f", and a {components[-1]}" - classes_text = ", ".join(component_classes[:-1]) + f", and {component_classes[-1]}" + components_text = ( + ", ".join(f"a {c}" for c in components[:-1]) + f", and a {components[-1]}" + ) + classes_text = ( + ", ".join(component_classes[:-1]) + f", and {component_classes[-1]}" + ) classes_short = [c.replace("[`", "[`~") for c in component_classes] - classes_text_short = ", ".join(classes_short[:-1]) + f", and {classes_short[-1]}" + classes_text_short = ( + ", ".join(classes_short[:-1]) + f", and {classes_short[-1]}" + ) intro = f"""Constructs a {class_name} which wraps {components_text} into a single processor. @@ -2799,7 +2820,9 @@ def generate_processor_intro(cls) -> str: return intro -def get_placeholders_dict(placeholders: set[str], model_name: str) -> Mapping[str, str | None]: +def get_placeholders_dict( + placeholders: set[str], model_name: str +) -> Mapping[str, str | None]: """ Get the dictionary of placeholders for the given model name. """ @@ -2821,9 +2844,15 @@ def get_placeholders_dict(placeholders: set[str], model_name: str) -> Mapping[st if place_holder_value is not None: if isinstance(place_holder_value, (list, tuple)): place_holder_value = ( - place_holder_value[-1] if place_holder_value[-1] is not None else place_holder_value[0] + place_holder_value[-1] + if place_holder_value[-1] is not None + else place_holder_value[0] ) - placeholders_dict[placeholder] = place_holder_value if place_holder_value is not None else placeholder + placeholders_dict[placeholder] = ( + place_holder_value + if place_holder_value is not None + else placeholder + ) else: placeholders_dict[placeholder] = placeholder @@ -3023,16 +3052,22 @@ def _format_type_annotation_recursive(type_hint): union_strs.append(_format_type_annotation_recursive(arg)) formatted_union = " | ".join(union_strs) # Include the rest of the Annotated metadata - remaining_args = [_format_type_annotation_recursive(arg) for arg in args[1:]] + remaining_args = [ + _format_type_annotation_recursive(arg) for arg in args[1:] + ] all_args = [formatted_union] + remaining_args return f"{origin_str}[{', '.join(all_args)}]" elif first_arg_origin is Union: # Old-style Union - format as Union[X, Y, Z] union_args = get_args(args[0]) - union_strs = [_format_type_annotation_recursive(arg) for arg in union_args] + union_strs = [ + _format_type_annotation_recursive(arg) for arg in union_args + ] formatted_union = f"Union[{', '.join(union_strs)}]" # Include the rest of the Annotated metadata - remaining_args = [_format_type_annotation_recursive(arg) for arg in args[1:]] + remaining_args = [ + _format_type_annotation_recursive(arg) for arg in args[1:] + ] all_args = [formatted_union] + remaining_args return f"{origin_str}[{', '.join(all_args)}]" @@ -3070,7 +3105,9 @@ def _format_type_annotation_recursive(type_hint): return type_str -def process_type_annotation(type_input, param_name: str | None = None) -> tuple[str, bool]: +def process_type_annotation( + type_input, param_name: str | None = None +) -> tuple[str, bool]: """ Unified function to process and format a parameter's type annotation. 
@@ -3138,27 +3175,45 @@ def process_type_annotation(type_input, param_name: str | None = None) -> tuple[ parts = [p for p in parts if p != "None"] param_type = " | ".join(parts) if parts else "" # Clean up module prefixes including typing - param_type = "".join(param_type.split("typing.")).replace("transformers.", "~").replace("builtins.", "") + param_type = ( + "".join(param_type.split("typing.")) + .replace("transformers.", "~") + .replace("builtins.", "") + ) - elif "typing" in param_type or "Union[" in param_type or "Optional[" in param_type or "[" in param_type: + elif ( + "typing" in param_type + or "Union[" in param_type + or "Optional[" in param_type + or "[" in param_type + ): # Complex typing construct or generic type - clean up typing module references param_type = "".join(param_type.split("typing.")).replace("transformers.", "~") elif "" - should NOT append param_name param_type = ( - param_type.replace("transformers.", "~").replace("builtins.", "").replace("", "") + param_type.replace("transformers.", "~") + .replace("builtins.", "") + .replace("", "") ) else: # Simple type or module path - only append param_name if it looks like a module path # This is legacy behavior for backwards compatibility - if param_name and "." in param_type and not param_type.split(".")[-1][0].isupper(): + if ( + param_name + and "." in param_type + and not param_type.split(".")[-1][0].isupper() + ): # Looks like a module path ending with an attribute param_type = f"{param_type.replace('transformers.', '~').replace('builtins', '')}.{param_name}" else: # Simple type name, don't append param_name - param_type = param_type.replace("transformers.", "~").replace("builtins.", "") + param_type = param_type.replace("transformers.", "~").replace( + "builtins.", "" + ) # Clean up ForwardRef if "ForwardRef" in param_type: @@ -3195,7 +3250,9 @@ def _process_parameter_type(param): return formatted_type, optional -def _get_parameter_info(param_name, documented_params, source_args_dict, param_type, optional): +def _get_parameter_info( + param_name, documented_params, source_args_dict, param_type, optional +): """ Get parameter documentation details from the appropriate source. Tensor shape, optional status and description are taken from the custom docstring in priority if available. @@ -3241,7 +3298,14 @@ def _get_parameter_info(param_name, documented_params, source_args_dict, param_t # Parameter is not documented is_documented = False - return param_type, optional_string, shape_string, additional_info, description, is_documented + return ( + param_type, + optional_string, + shape_string, + additional_info, + description, + is_documented, + ) def _process_regular_parameters( @@ -3273,7 +3337,9 @@ def _process_regular_parameters( # Use appropriate args source based on whether it's a processor or not if source_args_dict is None: if is_processor: - source_args_dict = get_args_doc_from_source([ModelArgs, ImageProcessorArgs, ProcessorArgs]) + source_args_dict = get_args_doc_from_source( + [ModelArgs, ImageProcessorArgs, ProcessorArgs] + ) else: source_args_dict = get_args_doc_from_source([ModelArgs, ImageProcessorArgs]) @@ -3283,7 +3349,9 @@ def _process_regular_parameters( # Skip parameters that should be ignored if ( param_name in ARGS_TO_IGNORE - or param_name.startswith("_") # Private/internal params (e.g. ClassVar-backed fields in configs) + or param_name.startswith( + "_" + ) # Private/internal params (e.g. 
ClassVar-backed fields in configs) or param.kind == inspect.Parameter.VAR_POSITIONAL or param.kind == inspect.Parameter.VAR_KEYWORD ): @@ -3306,7 +3374,14 @@ def _process_regular_parameters( if param.default != inspect._empty and param.default is not None: param_default = f", defaults to `{str(param.default)}`" - param_type, optional_string, shape_string, additional_info, description, is_documented = _get_parameter_info( + ( + param_type, + optional_string, + shape_string, + additional_info, + description, + is_documented, + ) = _get_parameter_info( param_name, documented_params, source_args_dict, param_type, optional ) @@ -3321,11 +3396,11 @@ def _process_regular_parameters( param_type = param_type if "`" in param_type else f"`{param_type}`" # Format the parameter docstring if additional_info: - param_docstring = f"{param_name} ({param_type}{additional_info}):{description}" - else: param_docstring = ( - f"{param_name} ({param_type}{shape_string}{optional_string}{param_default}):{description}" + f"{param_name} ({param_type}{additional_info}):{description}" ) + else: + param_docstring = f"{param_name} ({param_type}{shape_string}{optional_string}{param_default}):{description}" docstring += set_min_indent( param_docstring, indent_level + 8, @@ -3335,16 +3410,23 @@ def _process_regular_parameters( "type": param_type if param_type else "", "optional": optional, "shape": shape_string, - "description": description if description else "\n ", + "description": ( + description if description else "\n " + ), "default": param_default, } # Try to get the correct source file; for classes decorated with @strict (huggingface_hub), # func.__code__.co_filename points to the wrapper in huggingface_hub, not the config file. try: if parent_class is not None: - _source_file = inspect.getsourcefile(parent_class) or func.__code__.co_filename + _source_file = ( + inspect.getsourcefile(parent_class) or func.__code__.co_filename + ) else: - _source_file = inspect.getsourcefile(inspect.unwrap(func)) or func.__code__.co_filename + _source_file = ( + inspect.getsourcefile(inspect.unwrap(func)) + or func.__code__.co_filename + ) except (TypeError, OSError): _source_file = func.__code__.co_filename undocumented_parameters.append( @@ -3459,7 +3541,9 @@ def _is_processor_class(func, parent_class): # Python < 3.12 fallback: naming heuristics when __orig_bases__ is not set (cpython#103699). # Order matters: check ImageProcessorKwargs before ProcessorKwargs. -_BASIC_KWARGS_NAMES = frozenset({"ImagesKwargs", "ProcessingKwargs", "TextKwargs", "VideosKwargs", "AudioKwargs"}) +_BASIC_KWARGS_NAMES = frozenset( + {"ImagesKwargs", "ProcessingKwargs", "TextKwargs", "VideosKwargs", "AudioKwargs"} +) _BASIC_KWARGS_CLASSES = None # Lazy-loaded name -> class mapping @@ -3495,7 +3579,10 @@ def _get_base_kwargs_class(cls): parent = None for base in bases: if isinstance(base, type) and base not in (dict, object): - if getattr(base, "__name__", "") == "TypedDict" and getattr(base, "__module__", "") == "typing": + if ( + getattr(base, "__name__", "") == "TypedDict" + and getattr(base, "__module__", "") == "typing" + ): continue parent = base break @@ -3526,7 +3613,9 @@ def _get_base_kwargs_class(cls): current = parent -def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, indent_level, undocumented_parameters): +def _process_kwargs_parameters( + sig, func, parent_class, documented_kwargs, indent_level, undocumented_parameters +): """ Process **kwargs parameters if needed. 
@@ -3585,7 +3674,9 @@ def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, inden if kwargs_documentation is not None: documented_kwargs = parse_docstring(kwargs_documentation)[0] # Process each kwarg parameter - for param_name, param_type_annotation in kwarg_param.annotation.__args__[0].__annotations__.items(): + for param_name, param_type_annotation in kwarg_param.annotation.__args__[ + 0 + ].__annotations__.items(): # Handle nested kwargs structures for processors if is_processor and param_name.endswith("_kwargs"): @@ -3595,7 +3686,9 @@ def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, inden # Get the actual type (unwrap Optional if needed) actual_type = param_type_annotation type_name = getattr(param_type_annotation, "__name__", None) - if type_name is None and hasattr(param_type_annotation, "__origin__"): + if type_name is None and hasattr( + param_type_annotation, "__origin__" + ): # Handle Optional[Type] or Union cases args = getattr(param_type_annotation, "__args__", ()) for arg in args: @@ -3614,7 +3707,9 @@ def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, inden nested_kwargs_doc = getattr(actual_type, "__doc__", None) documented_nested_kwargs = {} if nested_kwargs_doc: - documented_nested_kwargs = parse_docstring(nested_kwargs_doc)[0] + documented_nested_kwargs = parse_docstring( + nested_kwargs_doc + )[0] # Only process fields that are documented in the custom kwargs class's own docstring # This prevents showing too many inherited parameters @@ -3623,20 +3718,29 @@ def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, inden continue # Process each field in the custom typed kwargs - for nested_param_name, nested_param_type in actual_type.__annotations__.items(): + for ( + nested_param_name, + nested_param_type, + ) in actual_type.__annotations__.items(): # Only document parameters that are explicitly documented in the TypedDict's docstring if nested_param_name not in documented_nested_kwargs: continue - nested_param_type_str, nested_optional = process_type_annotation( - nested_param_type, nested_param_name + nested_param_type_str, nested_optional = ( + process_type_annotation( + nested_param_type, nested_param_name + ) ) # Check for default value nested_param_default = "" if parent_class is not None: - nested_param_default = str(getattr(parent_class, nested_param_name, "")) + nested_param_default = str( + getattr(parent_class, nested_param_name, "") + ) nested_param_default = ( - f", defaults to `{nested_param_default}`" if nested_param_default != "" else "" + f", defaults to `{nested_param_default}`" + if nested_param_default != "" + else "" ) # Only use the TypedDict's own docstring, not source_args_dict @@ -3663,7 +3767,9 @@ def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, inden f"🚨 {nested_param_name} for {type_name} in file {func.__code__.co_filename} has no type" ) nested_param_type_str = ( - nested_param_type_str if "`" in nested_param_type_str else f"`{nested_param_type_str}`" + nested_param_type_str + if "`" in nested_param_type_str + else f"`{nested_param_type_str}`" ) # Format the parameter docstring (KWARGS_INDICATOR distinguishes from regular args) if nested_additional_info: @@ -3685,16 +3791,33 @@ def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, inden if documented_kwargs and param_name not in documented_kwargs: continue - param_type, optional = process_type_annotation(param_type_annotation, param_name) + param_type, 
optional = process_type_annotation( + param_type_annotation, param_name + ) # Check for default value param_default = "" if parent_class is not None: param_default = str(getattr(parent_class, param_name, "")) - param_default = f", defaults to `{param_default}`" if param_default != "" else "" + param_default = ( + f", defaults to `{param_default}`" + if param_default != "" + else "" + ) - param_type, optional_string, shape_string, additional_info, description, is_documented = ( - _get_parameter_info(param_name, documented_kwargs, source_args_dict, param_type, optional) + ( + param_type, + optional_string, + shape_string, + additional_info, + description, + is_documented, + ) = _get_parameter_info( + param_name, + documented_kwargs, + source_args_dict, + param_type, + optional, ) if is_documented: @@ -3760,7 +3883,9 @@ def _add_return_tensors_to_docstring(func, parent_class, docstring, indent_level is_image_processor_preprocess = _is_image_processor_class(func, parent_class) # If it's a processor __call__ method or an image processor preprocess method and return_tensors is not already documented - if (is_processor_call or is_image_processor_preprocess) and "return_tensors" not in docstring: + if ( + is_processor_call or is_image_processor_preprocess + ) and "return_tensors" not in docstring: # Get the return_tensors documentation from ImageProcessorArgs source_args_dict = ( get_args_doc_from_source(ProcessorArgs) @@ -3830,12 +3955,19 @@ def _process_parameters_section( # Process **kwargs parameters if needed kwargs_docstring, kwargs_summary = _process_kwargs_parameters( - sig, func, parent_class, documented_kwargs, indent_level, undocumented_parameters + sig, + func, + parent_class, + documented_kwargs, + indent_level, + undocumented_parameters, ) docstring += kwargs_docstring # Add return_tensors for processor __call__ methods if not already present - docstring = _add_return_tensors_to_docstring(func, parent_class, docstring, indent_level) + docstring = _add_return_tensors_to_docstring( + func, parent_class, docstring, indent_level + ) # Add **kwargs summary line after return_tensors docstring += kwargs_summary @@ -3894,7 +4026,9 @@ def _prepare_return_docstring(output_type, config_class, add_intro=True): # Import here to avoid circular import from .doc import PT_RETURN_INTRODUCTION - intro = PT_RETURN_INTRODUCTION.format(full_output_type=full_output_type, config_class=config_class) + intro = PT_RETURN_INTRODUCTION.format( + full_output_type=full_output_type, config_class=config_class + ) else: intro = f"Returns:\n `{full_output_type}`" if documented_params: @@ -3935,7 +4069,9 @@ def _prepare_return_docstring(output_type, config_class, add_intro=True): # additional_info contains shape and optional status param_line = f"- **{param_name}** (`{param_type}`{additional_info}) -- {param_description}" else: - param_line = f"- **{param_name}** (`{param_type}`) -- {param_description}" + param_line = ( + f"- **{param_name}** (`{param_type}`) -- {param_description}" + ) # Handle multi-line descriptions: # Split the description to handle continuations with proper indentation @@ -3973,10 +4109,15 @@ def _process_returns_section(func_documentation, sig, config_class, indent_level return_docstring = "" # Extract returns section from existing docstring if available - if func_documentation is not None and (match_start := _re_return.search(func_documentation)) is not None: + if ( + func_documentation is not None + and (match_start := _re_return.search(func_documentation)) is not None + ): match_end = 
_re_example.search(func_documentation) if match_end: - return_docstring = func_documentation[match_start.start() : match_end.start()] + return_docstring = func_documentation[ + match_start.start() : match_end.start() + ] func_documentation = func_documentation[match_end.start() :] else: return_docstring = func_documentation[match_start.start() :] @@ -3985,7 +4126,9 @@ def _process_returns_section(func_documentation, sig, config_class, indent_level # Otherwise, generate return docstring from return annotation if available elif sig.return_annotation is not None and sig.return_annotation != inspect._empty: add_intro, return_annotation = contains_type(sig.return_annotation, ModelOutput) - return_docstring = _prepare_return_docstring(return_annotation, config_class, add_intro=add_intro) + return_docstring = _prepare_return_docstring( + return_annotation, config_class, add_intro=add_intro + ) # PT_RETURN_INTRODUCTION already starts with \n, so only add blank line if it doesn't start with one if not return_docstring.startswith("\n"): return_docstring = "\n" + return_docstring @@ -3995,7 +4138,14 @@ def _process_returns_section(func_documentation, sig, config_class, indent_level def _process_example_section( - func_documentation, func, parent_class, class_name, model_name_lowercase, config_class, checkpoint, indent_level + func_documentation, + func, + parent_class, + class_name, + model_name_lowercase, + config_class, + checkpoint, + indent_level, ): """ Process the example section of the docstring. @@ -4016,7 +4166,9 @@ def _process_example_section( example_docstring = "" # Use existing example section if available (with or without an "Example:" header) - if func_documentation is not None and (match := _re_example.search(func_documentation)): + if func_documentation is not None and ( + match := _re_example.search(func_documentation) + ): example_docstring = func_documentation[match.start() :] example_docstring = "\n" + set_min_indent(example_docstring, indent_level + 4) # Skip examples for processors @@ -4034,15 +4186,23 @@ def _process_example_section( # Get checkpoint example if (checkpoint_example := checkpoint) is None: try: - checkpoint_example = get_checkpoint_from_config_class(CONFIG_MAPPING[model_name_lowercase]) + checkpoint_example = get_checkpoint_from_config_class( + CONFIG_MAPPING[model_name_lowercase] + ) except KeyError: # For models with inconsistent lowercase model name if model_name_lowercase in HARDCODED_CONFIG_FOR_MODELS: - CONFIG_MAPPING_NAMES = auto_module.configuration_auto.CONFIG_MAPPING_NAMES - config_class_name = HARDCODED_CONFIG_FOR_MODELS[model_name_lowercase] + CONFIG_MAPPING_NAMES = ( + auto_module.configuration_auto.CONFIG_MAPPING_NAMES + ) + config_class_name = HARDCODED_CONFIG_FOR_MODELS[ + model_name_lowercase + ] if config_class_name in CONFIG_MAPPING_NAMES.values(): model_name_for_auto_config = [ - k for k, v in CONFIG_MAPPING_NAMES.items() if v == config_class_name + k + for k, v in CONFIG_MAPPING_NAMES.items() + if v == config_class_name ][0] if model_name_for_auto_config in CONFIG_MAPPING: checkpoint_example = get_checkpoint_from_config_class( @@ -4072,12 +4232,16 @@ def _process_example_section( # Check if the model is in a pipeline to get an example for name_model_list_for_task in MODELS_TO_PIPELINE: try: - model_list_for_task = getattr(auto_module.modeling_auto, name_model_list_for_task) + model_list_for_task = getattr( + auto_module.modeling_auto, name_model_list_for_task + ) except (ImportError, AttributeError): continue if class_name in 
model_list_for_task.values(): pipeline_name = MODELS_TO_PIPELINE[name_model_list_for_task] - example_annotation = PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS[pipeline_name].format( + example_annotation = PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS[ + pipeline_name + ].format( model_class=class_name, checkpoint=checkpoint_example, expected_output="...", @@ -4085,7 +4249,9 @@ def _process_example_section( qa_target_start_index=14, qa_target_end_index=15, ) - example_docstring = set_min_indent(example_annotation, indent_level + 4) + example_docstring = set_min_indent( + example_annotation, indent_level + 4 + ) break return example_docstring @@ -4106,14 +4272,21 @@ def auto_method_docstring( # Use inspect to retrieve the method's signature sig = inspect.signature(func) - indent_level = get_indent_level(func) if not parent_class else get_indent_level(parent_class) + indent_level = ( + get_indent_level(func) if not parent_class else get_indent_level(parent_class) + ) # Get model information model_name_lowercase, class_name, config_class = _get_model_info(func, parent_class) func_documentation = func.__doc__ if custom_args is not None and func_documentation is not None: - func_documentation = "\n" + set_min_indent(custom_args.strip("\n"), 0) + "\n" + func_documentation + func_documentation = ( + "\n" + + set_min_indent(custom_args.strip("\n"), 0) + + "\n" + + func_documentation + ) elif custom_args is not None: func_documentation = "\n" + set_min_indent(custom_args.strip("\n"), 0) @@ -4123,7 +4296,9 @@ def auto_method_docstring( if not docstring.strip().endswith("\n"): docstring += "\n" else: - docstring = add_intro_docstring(func, class_name=class_name, indent_level=indent_level) + docstring = add_intro_docstring( + func, class_name=class_name, indent_level=indent_level + ) # Process Parameters section docstring += _process_parameters_section( @@ -4180,7 +4355,10 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No docstring_args = "" if "PreTrainedModel" in (x.__name__ for x in cls.__mro__): docstring_init = auto_method_docstring( - cls.__init__, parent_class=cls, custom_args=custom_args, checkpoint=checkpoint + cls.__init__, + parent_class=cls, + custom_args=custom_args, + checkpoint=checkpoint, ).__doc__.replace("Args:", "Parameters:") elif "ProcessorMixin" in (x.__name__ for x in cls.__mro__): is_processor = True @@ -4189,7 +4367,9 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No parent_class=cls, custom_args=custom_args, checkpoint=checkpoint, - source_args_dict=get_args_doc_from_source([ModelArgs, ImageProcessorArgs, ProcessorArgs]), + source_args_dict=get_args_doc_from_source( + [ModelArgs, ImageProcessorArgs, ProcessorArgs] + ), ).__doc__.replace("Args:", "Parameters:") elif "ModelOutput" in (x.__name__ for x in cls.__mro__): # We have a data class @@ -4228,7 +4408,9 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No if ancestor.__name__ == "PreTrainedConfig": break own_config_params |= { - k for k, v in getattr(ancestor, "__annotations__", {}).items() if get_origin(v) is not ClassVar + k + for k, v in getattr(ancestor, "__annotations__", {}).items() + if get_origin(v) is not ClassVar } allowed_params = own_config_params if own_config_params else None docstring_init = auto_method_docstring( @@ -4242,8 +4424,14 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No indent_level = get_indent_level(cls) model_name_lowercase = get_model_name(cls) - model_name_title = " 
".join([k.title() for k in model_name_lowercase.split("_")]) if model_name_lowercase else None - model_base_class = f"{model_name_title.title()}Model" if model_name_title is not None else None + model_name_title = ( + " ".join([k.title() for k in model_name_lowercase.split("_")]) + if model_name_lowercase + else None + ) + model_base_class = ( + f"{model_name_title.title()}Model" if model_name_title is not None else None + ) if model_name_lowercase is not None: try: model_base_class = getattr( @@ -4264,16 +4452,31 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No name = re.findall(rf"({'|'.join(ClassDocstring.__dict__.keys())})$", cls.__name__) - if name == [] and custom_intro is None and not is_dataclass and not is_processor and not is_image_processor: + if ( + name == [] + and custom_intro is None + and not is_dataclass + and not is_processor + and not is_image_processor + ): raise ValueError( f"`{cls.__name__}` is not registered in the auto doc. Here are the available classes: {ClassDocstring.__dict__.keys()}.\n" "Add a `custom_intro` to the decorator if you want to use `auto_docstring` on a class not registered in the auto doc." ) - if name != [] or custom_intro is not None or is_config or is_dataclass or is_processor or is_image_processor: + if ( + name != [] + or custom_intro is not None + or is_config + or is_dataclass + or is_processor + or is_image_processor + ): name = name[0] if name else None formatting_kwargs = {"model_name": model_name_title} if name == "Config": - formatting_kwargs.update({"model_base_class": model_base_class, "model_checkpoint": checkpoint}) + formatting_kwargs.update( + {"model_base_class": model_base_class, "model_checkpoint": checkpoint} + ) if custom_intro is not None: pre_block = equalize_indent(custom_intro, indent_level) if not pre_block.endswith("\n"): @@ -4294,9 +4497,15 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No else: pre_block = getattr(ClassDocstring, name).format(**formatting_kwargs) # Start building the docstring - docstring = set_min_indent(f"{pre_block}", indent_level) if len(pre_block) else "" - if name != "PreTrainedModel" and "PreTrainedModel" in (x.__name__ for x in cls.__mro__): - docstring += set_min_indent(f"{ClassDocstring.PreTrainedModel}", indent_level) + docstring = ( + set_min_indent(f"{pre_block}", indent_level) if len(pre_block) else "" + ) + if name != "PreTrainedModel" and "PreTrainedModel" in ( + x.__name__ for x in cls.__mro__ + ): + docstring += set_min_indent( + f"{ClassDocstring.PreTrainedModel}", indent_level + ) # Add the __init__ docstring if docstring_init: docstring += set_min_indent(f"\n{docstring_init}", indent_level) @@ -4307,15 +4516,30 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No doc_class = cls.__doc__ if cls.__doc__ else "" documented_kwargs = parse_docstring(doc_class)[0] for param_name, param_type_annotation in cls.__annotations__.items(): - param_type, optional = process_type_annotation(param_type_annotation, param_name) + param_type, optional = process_type_annotation( + param_type_annotation, param_name + ) # Check for default value param_default = "" param_default = str(getattr(cls, param_name, "")) - param_default = f", defaults to `{param_default}`" if param_default != "" else "" + param_default = ( + f", defaults to `{param_default}`" if param_default != "" else "" + ) - param_type, optional_string, shape_string, additional_info, description, is_documented = ( - 
_get_parameter_info(param_name, documented_kwargs, source_args_dict, param_type, optional) + ( + param_type, + optional_string, + shape_string, + additional_info, + description, + is_documented, + ) = _get_parameter_info( + param_name, + documented_kwargs, + source_args_dict, + param_type, + optional, ) if is_documented: @@ -4498,10 +4722,18 @@ class MyModelOutput(ImageClassifierOutput): def auto_docstring_decorator(obj): if len(obj.__qualname__.split(".")) > 1: return auto_method_docstring( - obj, custom_args=custom_args, custom_intro=custom_intro, checkpoint=checkpoint + obj, + custom_args=custom_args, + custom_intro=custom_intro, + checkpoint=checkpoint, ) else: - return auto_class_docstring(obj, custom_args=custom_args, custom_intro=custom_intro, checkpoint=checkpoint) + return auto_class_docstring( + obj, + custom_args=custom_args, + custom_intro=custom_intro, + checkpoint=checkpoint, + ) if obj: return auto_docstring_decorator(obj) diff --git a/test_future_annotations.py b/test_future_annotations.py new file mode 100644 index 000000000000..d0dc5574ece9 --- /dev/null +++ b/test_future_annotations.py @@ -0,0 +1,18 @@ +from __future__ import annotations +import inspect +from transformers.utils.auto_docstring import _process_kwargs_parameters + + +def test_with_future_annotations(): + # This should fail without the fix + def dummy_func(**kwargs: "ImagesKwargs"): + pass + + sig = inspect.signature(dummy_func) + # This line should trigger the bug + _process_kwargs_parameters(sig, dummy_func, None, {}, 0, []) + print("Success!") + + +if __name__ == "__main__": + test_with_future_annotations()