From ba9d2490bf7417ac9a9307b7e94d6edcafc6b4a1 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 26 Feb 2026 12:31:48 +0530 Subject: [PATCH 1/2] update --- docs/source/en/using-diffusers/automodel.md | 27 +++++++++++++++ src/diffusers/configuration_utils.py | 38 +++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/docs/source/en/using-diffusers/automodel.md b/docs/source/en/using-diffusers/automodel.md index d8cea79a10c9..82d4d14a10a9 100644 --- a/docs/source/en/using-diffusers/automodel.md +++ b/docs/source/en/using-diffusers/automodel.md @@ -97,5 +97,32 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th > ) > ``` +### Saving custom models + +Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file. + +```py +# my_model.py +from diffusers import ModelMixin, ConfigMixin + +class MyCustomModel(ModelMixin, ConfigMixin): + ... + +MyCustomModel.register_for_auto_class("AutoModel") + +model = MyCustomModel(...) +model.save_pretrained("./my_model") +``` + +The saved `config.json` will include the `auto_map` field. + +```json +{ + "auto_map": { + "AutoModel": "my_model.MyCustomModel" + } +} +``` + > [!NOTE] > Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide. \ No newline at end of file diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 2545c9db35b2..7a95ce20aaff 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -107,6 +107,38 @@ class ConfigMixin: has_compatibles = False _deprecated_kwargs = [] + _auto_class = None + + @classmethod + def register_for_auto_class(cls, auto_class="AutoModel"): + """ + Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(..., + trust_remote_code=True)`. + + When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class + to this class's module and class name. + + Args: + auto_class (`str` or type, *optional*, defaults to `"AutoModel"`): + The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself. + Currently only `"AutoModel"` is supported. + + Example: + + ```python + from diffusers import ModelMixin, ConfigMixin + + + class MyCustomModel(ModelMixin, ConfigMixin): ... + + + MyCustomModel.register_for_auto_class("AutoModel") + ``` + """ + if auto_class != "AutoModel": + raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.") + + cls._auto_class = auto_class def register_to_config(self, **kwargs): if self.config_name is None: @@ -621,6 +653,12 @@ def to_json_saveable(value): # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. _ = config_dict.pop("_pre_quantization_dtype", None) + if getattr(self, "_auto_class", None) is not None: + module = self.__class__.__module__.split(".")[-1] + auto_map = config_dict.get("auto_map", {}) + auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}" + config_dict["auto_map"] = auto_map + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: str | os.PathLike): From f20e0f4e4b6c56f2e998b218482baf721d5b8364 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 26 Feb 2026 12:42:11 +0530 Subject: [PATCH 2/2] update --- tests/models/test_models_auto.py | 53 ++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py index 3506f8fa0d17..b88647459f3a 100644 --- a/tests/models/test_models_auto.py +++ b/tests/models/test_models_auto.py @@ -1,9 +1,14 @@ +import json +import os +import tempfile import unittest from unittest.mock import MagicMock, patch from transformers import CLIPTextModel, LongformerModel +from diffusers import ConfigMixin from diffusers.models import AutoModel, UNet2DConditionModel +from diffusers.models.modeling_utils import ModelMixin class TestAutoModel(unittest.TestCase): @@ -100,3 +105,51 @@ def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class def test_from_config_raises_on_none(self): with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"): AutoModel.from_config(None) + + +class TestRegisterForAutoClass(unittest.TestCase): + def test_register_for_auto_class_sets_attribute(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + DummyModel.register_for_auto_class("AutoModel") + self.assertEqual(DummyModel._auto_class, "AutoModel") + + def test_register_for_auto_class_rejects_unsupported(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"): + DummyModel.register_for_auto_class("AutoPipeline") + + def test_auto_map_in_saved_config(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + DummyModel.register_for_auto_class("AutoModel") + model = DummyModel() + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_config(tmpdir) + config_path = os.path.join(tmpdir, "config.json") + with open(config_path, "r") as f: + config = json.load(f) + + self.assertIn("auto_map", config) + self.assertIn("AutoModel", config["auto_map"]) + module_name = DummyModel.__module__.split(".")[-1] + self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel") + + def test_no_auto_map_without_register(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + model = DummyModel() + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_config(tmpdir) + config_path = os.path.join(tmpdir, "config.json") + with open(config_path, "r") as f: + config = json.load(f) + + self.assertNotIn("auto_map", config)