From 9c1c6e6e226e80bf03e65c2c93d66a215c4b3987 Mon Sep 17 00:00:00 2001 From: DN6 Date: Sun, 8 Feb 2026 13:41:12 +0530 Subject: [PATCH 1/2] update --- src/diffusers/models/auto_model.py | 124 ++++++++++++++++++++++++++++- tests/models/test_models_auto.py | 68 +++++++++++++++- 2 files changed, 190 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 0a5b7fff1c66..734ef7763f6f 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -31,10 +31,132 @@ class AutoModel(ConfigMixin): def __init__(self, *args, **kwargs): raise EnvironmentError( f"{self.__class__.__name__} is designed to be instantiated " - f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`, " + f"`{self.__class__.__name__}.from_config(config)`, or " f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." ) + @classmethod + def from_config( + cls, pretrained_model_name_or_path_or_dict: Optional[Union[str, os.PathLike, dict]] = None, **kwargs + ): + r""" + Instantiate a model from a config dictionary or a pretrained model configuration file with random weights (no + pretrained weights are loaded). + + Parameters: + pretrained_model_name_or_path_or_dict (`str`, `os.PathLike`, or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model + configuration hosted on the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing a model configuration + file. + - A config dictionary. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model configuration, overriding the cached version if + it exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model configuration files or not. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether to trust remote code. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + Returns: + A model object instantiated from the config with random weights. + + Example: + + ```py + from diffusers import AutoModel + + model = AutoModel.from_config("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet") + ``` + """ + subfolder = kwargs.pop("subfolder", None) + trust_remote_code = kwargs.pop("trust_remote_code", False) + + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "revision", + "token", + ] + hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names} + + if pretrained_model_name_or_path_or_dict is None: + raise ValueError( + "Please provide a `pretrained_model_name_or_path_or_dict` as the first positional argument." + ) + + if isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)): + pretrained_model_name_or_path = pretrained_model_name_or_path_or_dict + cls.config_name = "config.json" + config = cls.load_config(pretrained_model_name_or_path, subfolder=subfolder, **hub_kwargs) + else: + config = pretrained_model_name_or_path_or_dict + pretrained_model_name_or_path = config.get("_name_or_path") + + library = None + orig_class_name = None + + if "_class_name" in config: + orig_class_name = config["_class_name"] + library = "diffusers" + elif "model_type" in config: + orig_class_name = "AutoModel" + library = "transformers" + else: + raise ValueError( + f"Couldn't find a model class associated with the config: {config}. Make sure the config " + "contains a `_class_name` or `model_type` key." + ) + + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_remote_code + ) + + if has_remote_code and trust_remote_code: + class_ref = config["auto_map"][cls.__name__] + module_file, class_name = class_ref.split(".") + module_file = module_file + ".py" + model_cls = get_class_from_dynamic_module( + pretrained_model_name_or_path, + subfolder=subfolder, + module_file=module_file, + class_name=class_name, + **hub_kwargs, + ) + else: + from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates + + model_cls, _ = get_class_obj_and_candidates( + library_name=library, + class_name=orig_class_name, + importable_classes=ALL_IMPORTABLE_CLASSES, + pipelines=None, + is_pipeline_module=False, + ) + + if model_cls is None: + raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") + + return model_cls.from_config(config, **kwargs) + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs): diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py index a70754343f30..66030d040ea0 100644 --- a/tests/models/test_models_auto.py +++ b/tests/models/test_models_auto.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch from transformers import CLIPTextModel, LongformerModel @@ -30,3 +30,69 @@ def test_load_from_config_without_subfolder(self): def test_load_from_model_index(self): model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") assert isinstance(model, CLIPTextModel) + + +class TestAutoModelFromConfig(unittest.TestCase): + @patch( + "diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates", + return_value=(MagicMock(), None), + ) + def test_from_config_with_dict_diffusers_class(self, mock_get_class): + config = {"_class_name": "UNet2DConditionModel", "sample_size": 64} + mock_model = MagicMock() + mock_get_class.return_value[0].from_config.return_value = mock_model + + result = AutoModel.from_config(config) + + mock_get_class.assert_called_once_with( + library_name="diffusers", + class_name="UNet2DConditionModel", + importable_classes=unittest.mock.ANY, + pipelines=None, + is_pipeline_module=False, + ) + mock_get_class.return_value[0].from_config.assert_called_once_with(config) + assert result is mock_model + + @patch( + "diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates", + return_value=(MagicMock(), None), + ) + @patch("diffusers.models.AutoModel.load_config", return_value={"_class_name": "UNet2DConditionModel"}) + def test_from_config_with_string_path(self, mock_load_config, mock_get_class): + mock_model = MagicMock() + mock_get_class.return_value[0].from_config.return_value = mock_model + + result = AutoModel.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet") + + mock_load_config.assert_called_once() + assert result is mock_model + + def test_from_config_raises_on_missing_class_info(self): + config = {"some_key": "some_value"} + with self.assertRaises(ValueError, msg="Couldn't find a model class"): + AutoModel.from_config(config) + + @patch( + "diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates", + return_value=(MagicMock(), None), + ) + def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class): + config = {"model_type": "clip_text_model"} + mock_model = MagicMock() + mock_get_class.return_value[0].from_config.return_value = mock_model + + result = AutoModel.from_config(config) + + mock_get_class.assert_called_once_with( + library_name="transformers", + class_name="AutoModel", + importable_classes=unittest.mock.ANY, + pipelines=None, + is_pipeline_module=False, + ) + assert result is mock_model + + def test_from_config_raises_on_none(self): + with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"): + AutoModel.from_config(None) From 20751a522190e5b56801270addac4101a1bd8386 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 11 Feb 2026 15:35:01 +0530 Subject: [PATCH 2/2] update --- src/diffusers/models/auto_model.py | 34 +++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index 734ef7763f6f..edddb45b775b 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -104,26 +104,10 @@ def from_config( if isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)): pretrained_model_name_or_path = pretrained_model_name_or_path_or_dict - cls.config_name = "config.json" config = cls.load_config(pretrained_model_name_or_path, subfolder=subfolder, **hub_kwargs) else: config = pretrained_model_name_or_path_or_dict - pretrained_model_name_or_path = config.get("_name_or_path") - - library = None - orig_class_name = None - - if "_class_name" in config: - orig_class_name = config["_class_name"] - library = "diffusers" - elif "model_type" in config: - orig_class_name = "AutoModel" - library = "transformers" - else: - raise ValueError( - f"Couldn't find a model class associated with the config: {config}. Make sure the config " - "contains a `_class_name` or `model_type` key." - ) + pretrained_model_name_or_path = config.get("_name_or_path", None) has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] trust_remote_code = resolve_trust_remote_code( @@ -142,18 +126,30 @@ def from_config( **hub_kwargs, ) else: + if "_class_name" in config: + class_name = config["_class_name"] + library = "diffusers" + elif "model_type" in config: + class_name = "AutoModel" + library = "transformers" + else: + raise ValueError( + f"Couldn't find a model class associated with the config: {config}. Make sure the config " + "contains a `_class_name` or `model_type` key." + ) + from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates model_cls, _ = get_class_obj_and_candidates( library_name=library, - class_name=orig_class_name, + class_name=class_name, importable_classes=ALL_IMPORTABLE_CLASSES, pipelines=None, is_pipeline_module=False, ) if model_cls is None: - raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") + raise ValueError(f"AutoModel can't find a model linked to {class_name}.") return model_cls.from_config(config, **kwargs)