Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/diffusers/pipelines/prx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,25 @@
else:
_import_structure["pipeline_prx"] = ["PRXPipeline"]

# Import T5GemmaEncoder for pipeline loading compatibility
# Wrap T5GemmaEncoder to pass config.encoder (T5GemmaModuleConfig) instead of the
# composite T5GemmaConfig, which lacks flat attributes expected by T5GemmaEncoder.__init__.
try:
if is_transformers_available():
import transformers
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder as _T5GemmaEncoder

class T5GemmaEncoder(_T5GemmaEncoder):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
if "config" not in kwargs:
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig

config = T5GemmaConfig.from_pretrained(pretrained_model_name_or_path)
if hasattr(config, "encoder"):
kwargs["config"] = config.encoder
return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

_additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
# Patch transformers module directly for serialization
if not hasattr(transformers, "T5GemmaEncoder"):
transformers.T5GemmaEncoder = T5GemmaEncoder
except ImportError:
Expand Down