diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 8868e942ce3d..e279a4f7ddc9 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -801,13 +801,7 @@ def load_sub_model(
     # add kwargs to loading method
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     loading_kwargs = {}
-    if issubclass(class_obj, torch.nn.Module):
-        loading_kwargs["torch_dtype"] = torch_dtype
-    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
-        loading_kwargs["provider"] = provider
-        loading_kwargs["sess_options"] = sess_options
-        loading_kwargs["provider_options"] = provider_options
-
+
     is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
 
     if is_transformers_available():
@@ -820,6 +814,17 @@ def load_sub_model(
         and issubclass(class_obj, PreTrainedModel)
         and transformers_version >= version.parse("4.20.0")
     )
+
+    # For transformers models, use 'dtype' instead of 'torch_dtype' to avoid deprecation warnings
+    if issubclass(class_obj, torch.nn.Module):
+        if is_transformers_model:
+            loading_kwargs["dtype"] = torch_dtype
+        else:
+            loading_kwargs["torch_dtype"] = torch_dtype
+    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
+        loading_kwargs["provider"] = provider
+        loading_kwargs["sess_options"] = sess_options
+        loading_kwargs["provider_options"] = provider_options
 
     # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
     # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
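
Note for reviewers: a minimal sketch of the kwarg difference this patch routes around, assuming a recent transformers release (4.56+) where from_pretrained accepts dtype and emits a deprecation warning for torch_dtype; the model id is illustrative only, not something this diff loads.

import torch
from transformers import CLIPTextModel

# Old spelling: still loads, but recent transformers versions warn that
# `torch_dtype` is deprecated in favor of `dtype`.
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float16
)

# New spelling: what the patched load_sub_model passes when
# is_transformers_model is True. Diffusers' own ModelMixin subclasses
# (the else branch above) keep receiving `torch_dtype`.
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", dtype=torch.float16
)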