diff --git a/src/diffusers/pipelines/prx/__init__.py b/src/diffusers/pipelines/prx/__init__.py
index 87aaefbd1368..ad2948e92e04 100644
--- a/src/diffusers/pipelines/prx/__init__.py
+++ b/src/diffusers/pipelines/prx/__init__.py
@@ -24,14 +24,25 @@ else:
     _import_structure["pipeline_prx"] = ["PRXPipeline"]


-# Import T5GemmaEncoder for pipeline loading compatibility
+# Wrap T5GemmaEncoder to pass config.encoder (T5GemmaModuleConfig) instead of the
+# composite T5GemmaConfig, which lacks flat attributes expected by T5GemmaEncoder.__init__.
 try:
     if is_transformers_available():
         import transformers
-        from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+        from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder as _T5GemmaEncoder
+
+        class T5GemmaEncoder(_T5GemmaEncoder):
+            @classmethod
+            def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+                if "config" not in kwargs:
+                    from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig
+
+                    config = T5GemmaConfig.from_pretrained(pretrained_model_name_or_path)
+                    if hasattr(config, "encoder"):
+                        kwargs["config"] = config.encoder
+                return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

         _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
-        # Patch transformers module directly for serialization
         if not hasattr(transformers, "T5GemmaEncoder"):
             transformers.T5GemmaEncoder = T5GemmaEncoder
 except ImportError:
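Reviewer note: a minimal usage sketch (not part of the patch) of what the `from_pretrained` override enables. The checkpoint id below is a placeholder, and the top-level import assumes the pipeline's lazy-module wiring surfaces `T5GemmaEncoder` through `_additional_imports`:

```python
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder as StockT5GemmaEncoder

from diffusers.pipelines.prx import T5GemmaEncoder  # wrapped class (assumes lazy-module export)

repo_id = "org/t5gemma-checkpoint"  # placeholder id, not a real repository

# With the wrapper: the composite T5GemmaConfig is resolved to its `encoder`
# sub-config automatically before delegating to the stock loader.
encoder = T5GemmaEncoder.from_pretrained(repo_id)

# Equivalent manual form with the unwrapped class, for comparison:
config = T5GemmaConfig.from_pretrained(repo_id)
encoder = StockT5GemmaEncoder.from_pretrained(repo_id, config=config.encoder)
```

The override only changes how the config is resolved; weight loading itself is still handled by the stock `from_pretrained`, and callers that pass an explicit `config` kwarg bypass the wrapper's resolution entirely.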