From 542ab9604b0f2759d495301bb7595ad8a28da6f9 Mon Sep 17 00:00:00 2001
From: DavidBert
Date: Fri, 13 Feb 2026 19:14:02 +0000
Subject: [PATCH] Fix T5GemmaEncoder loading for transformers 5.x composite
 T5GemmaConfig

---
 src/diffusers/pipelines/prx/__init__.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/prx/__init__.py b/src/diffusers/pipelines/prx/__init__.py
index 87aaefbd1368..ad2948e92e04 100644
--- a/src/diffusers/pipelines/prx/__init__.py
+++ b/src/diffusers/pipelines/prx/__init__.py
@@ -24,14 +24,25 @@
 else:
     _import_structure["pipeline_prx"] = ["PRXPipeline"]
 
-# Import T5GemmaEncoder for pipeline loading compatibility
+# Wrap T5GemmaEncoder to pass config.encoder (T5GemmaModuleConfig) instead of the
+# composite T5GemmaConfig, which lacks flat attributes expected by T5GemmaEncoder.__init__.
 try:
     if is_transformers_available():
         import transformers
-        from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+        from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder as _T5GemmaEncoder
+
+        class T5GemmaEncoder(_T5GemmaEncoder):
+            @classmethod
+            def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+                if "config" not in kwargs:
+                    from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig
+
+                    config = T5GemmaConfig.from_pretrained(pretrained_model_name_or_path)
+                    if hasattr(config, "encoder"):
+                        kwargs["config"] = config.encoder
+                return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
         _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
-        # Patch transformers module directly for serialization
         if not hasattr(transformers, "T5GemmaEncoder"):
             transformers.T5GemmaEncoder = T5GemmaEncoder
 except ImportError:
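
A minimal usage sketch of the behavior this patch restores (the checkpoint path is
hypothetical, and importing the wrapper from diffusers.pipelines.prx assumes the
package's lazy-import machinery exposes entries registered in _additional_imports):

    from diffusers.pipelines.prx import T5GemmaEncoder

    # Under transformers 5.x, T5Gemma checkpoints ship a composite T5GemmaConfig
    # with nested encoder/decoder sub-configs. The wrapped from_pretrained resolves
    # config.encoder before delegating to the superclass, so loading the encoder
    # alone no longer fails on the flat attributes T5GemmaEncoder.__init__ expects.
    encoder = T5GemmaEncoder.from_pretrained("path/to/t5gemma-checkpoint")  # hypothetical path

    # Passing an explicit config still takes precedence: the wrapper only resolves
    # config.encoder when no "config" kwarg is supplied.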