Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion src/diffusers/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,128 @@ class AutoModel(ConfigMixin):
def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`, "
f"`{self.__class__.__name__}.from_config(config)`, or "
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
)

@classmethod
def from_config(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh I think maybe we should support this in from_pretrained()? even for scheduler and such weightless things, we load them using from_pretrained(), no?

we normally use from_config like this, would be confusing I think

pipe = DiffusionPipeline.from_pretrained(..)
pipe.scheduler = NEWSCHEDULER.from_config(pipe.config)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is true. But the SchedulerMixin's from_pretrained just runs from_config under the hood.

config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
return_commit_hash=True,
**kwargs,
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

I thought perhaps we could make an exception for just the AutoModel class? The load_config and then from_config flow doesn't work too well when trying to load custom code.

Copy link
Collaborator

@yiyixuxu yiyixuxu Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh thanks for explaining. I think I got it now.
but is it possible to not re-use the from_config API with a different usage pattern? I think maybe we can add a new arg to from_pretrained to support this use case? e.g load_weights=True, init_from_config=False etc

Even for custom components that come with the weights, we neeed to support for this use case, for example training, or stuff like here we need to initialize it to inspect model structure for quantisation https://github.com/cubiq/Mellon/blob/main/modules/ModularDiffusers/loaders.py#L181

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or maybe a brand new method if it you prefer

Copy link
Collaborator

@yiyixuxu yiyixuxu Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init_from_hub? from_pretrained_config? etc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO from_config is the preferred option because it works well with ComponentSpec's default_creation_method as well. So we don't need to introduce too many changes to modular get this functionality.

Models currently support passing in a pretrained_model_name_or_path. I know that there was a plan to deprecate this behaviour (but I'm not clear as to why the deprecation is needed?)

deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
if "Scheduler" in cls.__name__:
deprecation_message += (
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
" be removed in v1.0.0."
)
elif "Model" in cls.__name__:
deprecation_message += (
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
" instead. This functionality will be removed in v1.0.0."
)
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but I'm not clear as to why the deprecation is needed?

I think it's before my time in but if I have to guess, they deprecated it after adding from_pretrained() to scheduler to make API cleaner: from_pretrained handles loading (without or without weights) and from_config takes only config dict

so if we undo the deprecation & introduce the from_config on AutoModel, the mental model would be something like this?

  • from_pretrained create the model with the hub_id with default behavior: load weigths for models, create config for weightless things) and this correspond so the default_creation_method in ComponentSpec
  • from_config is used to explicitly create from config without weights, acccept both a config dict and path

this sounds fine to me

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO from_config is the preferred option because it works well with ComponentSpec's default_creation_method as wel

One note on this though — default_creation_method is currently aligned with the current from_pretrained/from_config usage. e.g. the default creation method would be from_pretrained for weightless components as long as it needs to load a config from hub, image processor uses from_config because it's created from a config dict defined in componenet spec directly

not sure if this complicates anything so just something to keep in mind

cls, pretrained_model_name_or_path_or_dict: Optional[Union[str, os.PathLike, dict]] = None, **kwargs
):
r"""
Instantiate a model from a config dictionary or a pretrained model configuration file with random weights (no
pretrained weights are loaded).

Parameters:
pretrained_model_name_or_path_or_dict (`str`, `os.PathLike`, or `dict`):
Can be either:

- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model
configuration hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing a model configuration
file.
- A config dictionary.

cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model configuration, overriding the cached version if
it exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model configuration files or not.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether to trust remote code.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.

Returns:
A model object instantiated from the config with random weights.

Example:

```py
from diffusers import AutoModel

model = AutoModel.from_config("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
```
"""
subfolder = kwargs.pop("subfolder", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)

hub_kwargs_names = [
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"revision",
"token",
]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}

if pretrained_model_name_or_path_or_dict is None:
raise ValueError(
"Please provide a `pretrained_model_name_or_path_or_dict` as the first positional argument."
)

if isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
pretrained_model_name_or_path = pretrained_model_name_or_path_or_dict
config = cls.load_config(pretrained_model_name_or_path, subfolder=subfolder, **hub_kwargs)
else:
config = pretrained_model_name_or_path_or_dict
pretrained_model_name_or_path = config.get("_name_or_path", None)

has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)

if has_remote_code and trust_remote_code:
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
model_cls = get_class_from_dynamic_module(
pretrained_model_name_or_path,
subfolder=subfolder,
module_file=module_file,
class_name=class_name,
**hub_kwargs,
)
else:
if "_class_name" in config:
class_name = config["_class_name"]
library = "diffusers"
elif "model_type" in config:
class_name = "AutoModel"
library = "transformers"
else:
raise ValueError(
f"Couldn't find a model class associated with the config: {config}. Make sure the config "
"contains a `_class_name` or `model_type` key."
)

from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates

model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)

if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {class_name}.")

return model_cls.from_config(config, **kwargs)

@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
Expand Down
68 changes: 67 additions & 1 deletion tests/models/test_models_auto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from transformers import CLIPTextModel, LongformerModel

Expand Down Expand Up @@ -30,3 +30,69 @@ def test_load_from_config_without_subfolder(self):
def test_load_from_model_index(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
assert isinstance(model, CLIPTextModel)


class TestAutoModelFromConfig(unittest.TestCase):
@patch(
"diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates",
return_value=(MagicMock(), None),
)
def test_from_config_with_dict_diffusers_class(self, mock_get_class):
config = {"_class_name": "UNet2DConditionModel", "sample_size": 64}
mock_model = MagicMock()
mock_get_class.return_value[0].from_config.return_value = mock_model

result = AutoModel.from_config(config)

mock_get_class.assert_called_once_with(
library_name="diffusers",
class_name="UNet2DConditionModel",
importable_classes=unittest.mock.ANY,
pipelines=None,
is_pipeline_module=False,
)
mock_get_class.return_value[0].from_config.assert_called_once_with(config)
assert result is mock_model

@patch(
"diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates",
return_value=(MagicMock(), None),
)
@patch("diffusers.models.AutoModel.load_config", return_value={"_class_name": "UNet2DConditionModel"})
def test_from_config_with_string_path(self, mock_load_config, mock_get_class):
mock_model = MagicMock()
mock_get_class.return_value[0].from_config.return_value = mock_model

result = AutoModel.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")

mock_load_config.assert_called_once()
assert result is mock_model

def test_from_config_raises_on_missing_class_info(self):
config = {"some_key": "some_value"}
with self.assertRaises(ValueError, msg="Couldn't find a model class"):
AutoModel.from_config(config)

@patch(
"diffusers.pipelines.pipeline_loading_utils.get_class_obj_and_candidates",
return_value=(MagicMock(), None),
)
def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class):
config = {"model_type": "clip_text_model"}
mock_model = MagicMock()
mock_get_class.return_value[0].from_config.return_value = mock_model

result = AutoModel.from_config(config)

mock_get_class.assert_called_once_with(
library_name="transformers",
class_name="AutoModel",
importable_classes=unittest.mock.ANY,
pipelines=None,
is_pipeline_module=False,
)
assert result is mock_model

def test_from_config_raises_on_none(self):
with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
AutoModel.from_config(None)
Loading