diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index a8781c8042a6..c95609599089 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -543,7 +543,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[s if kwargs.get("dtype") == "auto": _ = kwargs.pop("dtype") # to not overwrite the quantization_config if config has a quantization_config - if kwargs.get("quantization_config") is not None: + if "quantization_config" in kwargs: _ = kwargs.pop("quantization_config") config, kwargs = AutoConfig.from_pretrained( @@ -560,7 +560,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[s kwargs["torch_dtype"] = "auto" if kwargs_orig.get("dtype", None) == "auto": kwargs["dtype"] = "auto" - if kwargs_orig.get("quantization_config", None) is not None: + if "quantization_config" in kwargs_orig: kwargs["quantization_config"] = kwargs_orig["quantization_config"] has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map