Skip to content

Commit 9ab1487

Browse files
wrapper pipelines + make style
1 parent 62c682e commit 9ab1487

File tree

8 files changed

+557
-42
lines changed

8 files changed

+557
-42
lines changed

scripts/convert_cosmos_to_diffusers.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@
6262
EDMEulerScheduler,
6363
FlowMatchEulerDiscreteScheduler,
6464
)
65-
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos_2_5_PredictBase
65+
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBase
6666

6767

6868
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -527,7 +527,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
527527
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
528528

529529

530-
def save_pipeline_cosmos_2_5(args, transformer, vae):
530+
def save_pipeline_cosmos2_5(args, transformer, vae):
531531
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
532532
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
533533

@@ -538,7 +538,7 @@ def save_pipeline_cosmos_2_5(args, transformer, vae):
538538

539539
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
540540

541-
pipe = Cosmos_2_5_PredictBase(
541+
pipe = Cosmos2_5_PredictBase(
542542
text_encoder=text_encoder,
543543
tokenizer=tokenizer,
544544
transformer=transformer,
@@ -613,6 +613,6 @@ def get_args():
613613
assert args.tokenizer_path is not None
614614
save_pipeline_cosmos_2_0(args, transformer, vae)
615615
elif "Cosmos-2.5" in args.transformer_type:
616-
save_pipeline_cosmos_2_5(args, transformer, vae)
616+
save_pipeline_cosmos2_5(args, transformer, vae)
617617
else:
618618
raise AssertionError(f"{args.transformer_type} not supported")

src/diffusers/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -463,7 +463,10 @@
463463
"CogView4ControlPipeline",
464464
"CogView4Pipeline",
465465
"ConsisIDPipeline",
466-
"Cosmos_2_5_PredictBase",
466+
"Cosmos2_5_PredictBase",
467+
"Cosmos2_5_PredictImage2World",
468+
"Cosmos2_5_PredictText2World",
469+
"Cosmos2_5_PredictVideo2World",
467470
"Cosmos2TextToImagePipeline",
468471
"Cosmos2VideoToWorldPipeline",
469472
"CosmosTextToWorldPipeline",
@@ -1173,9 +1176,12 @@
11731176
CogView4ControlPipeline,
11741177
CogView4Pipeline,
11751178
ConsisIDPipeline,
1179+
Cosmos2_5_PredictBase,
1180+
Cosmos2_5_PredictImage2World,
1181+
Cosmos2_5_PredictText2World,
1182+
Cosmos2_5_PredictVideo2World,
11761183
Cosmos2TextToImagePipeline,
11771184
Cosmos2VideoToWorldPipeline,
1178-
Cosmos_2_5_PredictBase,
11791185
CosmosTextToWorldPipeline,
11801186
CosmosVideoToWorldPipeline,
11811187
CycleDiffusionPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,10 @@
165165
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
166166
_import_structure["consisid"] = ["ConsisIDPipeline"]
167167
_import_structure["cosmos"] = [
168-
"Cosmos_2_5_PredictBase",
168+
"Cosmos2_5_PredictBase",
169+
"Cosmos2_5_PredictImage2World",
170+
"Cosmos2_5_PredictText2World",
171+
"Cosmos2_5_PredictVideo2World",
169172
"Cosmos2TextToImagePipeline",
170173
"CosmosTextToWorldPipeline",
171174
"CosmosVideoToWorldPipeline",
@@ -617,9 +620,12 @@
617620
StableDiffusionXLControlNetXSPipeline,
618621
)
619622
from .cosmos import (
623+
Cosmos2_5_PredictBase,
624+
Cosmos2_5_PredictImage2World,
625+
Cosmos2_5_PredictText2World,
626+
Cosmos2_5_PredictVideo2World,
620627
Cosmos2TextToImagePipeline,
621628
Cosmos2VideoToWorldPipeline,
622-
Cosmos_2_5_PredictBase,
623629
CosmosTextToWorldPipeline,
624630
CosmosVideoToWorldPipeline,
625631
)

src/diffusers/pipelines/cosmos/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,12 @@
2222

2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
25-
_import_structure["pipeline_cosmos2_5_predict"] = ["Cosmos_2_5_PredictBase", "retrieve_latents"]
25+
_import_structure["pipeline_cosmos2_5_predict"] = [
26+
"Cosmos2_5_PredictBase",
27+
"Cosmos2_5_PredictImage2World",
28+
"Cosmos2_5_PredictText2World",
29+
"Cosmos2_5_PredictVideo2World",
30+
]
2631
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
2732
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
2833
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
@@ -36,7 +41,12 @@
3641
except OptionalDependencyNotAvailable:
3742
from ...utils.dummy_torch_and_transformers_objects import *
3843
else:
39-
from .pipeline_cosmos2_5_predict import Cosmos_2_5_PredictBase, retrieve_latents
44+
from .pipeline_cosmos2_5_predict import (
45+
Cosmos2_5_PredictBase,
46+
Cosmos2_5_PredictImage2World,
47+
Cosmos2_5_PredictText2World,
48+
Cosmos2_5_PredictVideo2World,
49+
)
4050
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
4151
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
4252
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline

0 commit comments

Comments
 (0)