Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/cosmos.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ output.save("output.png")
- all
- __call__

## Cosmos2_5_PredictBasePipeline

[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__

## CosmosPipelineOutput

[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
Expand Down
135 changes: 127 additions & 8 deletions scripts/convert_cosmos_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,55 @@
"""
# Cosmos 2 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2-2B-Text2Image
```
convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_ckpt_path $transformer_ckpt_path \
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
--text_encoder_path google-t5/t5-11b \
--tokenizer_path google-t5/t5-11b \
--vae_type wan2.1 \
--output_path converted/cosmos-p2-t2i-2b \
--save_pipeline
```
# Cosmos 2.5 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2.5-2B
```
Convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/cosmos-p2.5-base-2b \
--save_pipeline
```
"""

import argparse
import pathlib
import sys
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast

from diffusers import (
AutoencoderKLCosmos,
Expand All @@ -17,7 +61,9 @@
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline


def remove_keys_(key: str, state_dict: Dict[str, Any]):
Expand Down Expand Up @@ -233,6 +279,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.5-Predict-Base-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}

VAE_KEYS_RENAME_DICT = {
Expand Down Expand Up @@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
elif "Cosmos-2.5" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False

Expand All @@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
Expand All @@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
continue
handler_fn_inplace(key, original_state_dict)

expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
for k in missing_keys:
print(k)
sys.exit(1)
if unexpected_keys:
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
for k in unexpected_keys:
print(k)
sys.exit(2)

transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer

Expand Down Expand Up @@ -444,17 +528,45 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")


def save_pipeline_cosmos2_5(args, transformer, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"

text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
text_encoder_path, torch_dtype="auto", device_map="cpu"
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

scheduler = UniPCMultistepScheduler(
use_karras_sigmas=True,
use_flow_sigmas=True,
prediction_type="flow_prediction",
sigma_max=200.0,
sigma_min=0.01,
)

pipe = Cosmos2_5_PredictBasePipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument(
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--text_encoder_path", type=str, default=None)
parser.add_argument("--tokenizer_path", type=str, default=None)
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
Expand All @@ -477,8 +589,6 @@ def get_args():
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None

if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
Expand All @@ -490,17 +600,26 @@ def get_args():
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
else:
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
else:
raise AssertionError(f"{args.transformer_type} not supported")

if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
else:
assert False
raise AssertionError(f"{args.transformer_type} not supported")
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
Expand Down Expand Up @@ -1175,6 +1176,7 @@
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
Expand Down
13 changes: 13 additions & 0 deletions src/diffusers/models/transformers/transformer_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,9 @@ def __init__(
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
concat_padding_mask: bool = True,
extra_pos_embed_type: Optional[str] = "learnable",
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
Expand Down Expand Up @@ -485,6 +488,12 @@ def __init__(
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
)

if self.config.use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)

self.gradient_checkpointing = False

def forward(
Expand Down Expand Up @@ -524,6 +533,7 @@ def forward(
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w

hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]

Expand All @@ -546,6 +556,9 @@ def forward(
else:
assert False

if self.config.use_crossattn_projection:
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)

# 5. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = [
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
Expand Down Expand Up @@ -622,6 +623,7 @@
StableDiffusionXLControlNetXSPipeline,
)
from .cosmos import (
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cosmos2_5_predict"] = [
"Cosmos2_5_PredictBasePipeline",
]
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
Expand All @@ -35,6 +38,9 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cosmos2_5_predict import (
Cosmos2_5_PredictBasePipeline,
)
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
Expand Down
Loading