
Commit f6519ec

cosmos predict2.5 base: convert chkpt & pipeline
- New scheduler: scheduling_flow_unipc_multistep.py
- Changes to TransformerCosmos for text embeddings via crossattn_proj
1 parent 5e48f46 commit f6519ec

7 files changed: +2466 -8 lines changed


scripts/convert_cosmos_to_diffusers.py

Lines changed: 120 additions & 8 deletions
@@ -1,11 +1,55 @@
+"""
+# Cosmos 2 Predict
+
+Download checkpoint
+```bash
+hf download nvidia/Cosmos-Predict2-2B-Text2Image
+```
+
+convert checkpoint
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
+    --text_encoder_path google-t5/t5-11b \
+    --tokenizer_path google-t5/t5-11b \
+    --vae_type wan2.1 \
+    --output_path converted/cosmos-p2-t2i-2b \
+    --save_pipeline
+```
+
+# Cosmos 2.5 Predict
+
+Download checkpoint
+```bash
+hf download nvidia/Cosmos-Predict2.5-2B
+```
+
+Convert checkpoint
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Predict-Base-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/cosmos-p2.5-base-2b \
+    --save_pipeline
+```
+
+"""
+
 import argparse
 import pathlib
+import sys
 from typing import Any, Dict
 
 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import snapshot_download
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
 
 from diffusers import (
     AutoencoderKLCosmos,
@@ -18,6 +62,7 @@
     EDMEulerScheduler,
     FlowMatchEulerDiscreteScheduler,
 )
+from diffusers.pipelines.cosmos.pipeline_cosmos25_predict import Cosmos25PredictBase
 
 
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +278,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "concat_padding_mask": True,
         "extra_pos_embed_type": None,
     },
+    "Cosmos-2.5-Predict-Base-2B": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
+        "extra_pos_embed_type": None,
+        "use_crossattn_projection": True,
+        "crossattn_proj_in_channels": 100352,
+        "encoder_hidden_states_channels": 1024,
+    },
 }
 
 VAE_KEYS_RENAME_DICT = {
@@ -334,6 +398,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
     elif "Cosmos-2.0" in transformer_type:
         TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
         TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    elif "Cosmos-2.5" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
     else:
         assert False
 
@@ -347,6 +414,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
             new_key = new_key.removeprefix(PREFIX_KEY)
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
+        print(key, "->", new_key, flush=True)
         update_state_dict_(original_state_dict, key, new_key)
 
     for key in list(original_state_dict.keys()):
@@ -355,6 +423,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
                 continue
             handler_fn_inplace(key, original_state_dict)
 
+    expected_keys = set(transformer.state_dict().keys())
+    mapped_keys = set(original_state_dict.keys())
+    missing_keys = expected_keys - mapped_keys
+    unexpected_keys = mapped_keys - expected_keys
+    if missing_keys:
+        print(f"ERROR: missing keys ({len(missing_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in missing_keys:
+            print(k)
+        sys.exit(1)
+    if unexpected_keys:
+        print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in unexpected_keys:
+            print(k)
+        sys.exit(2)
+
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
     return transformer
 
@@ -444,17 +527,39 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
     pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
 
+def save_pipeline_cosmos_2_5(args, transformer, vae):
+    text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
+    tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
+
+    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        text_encoder_path, torch_dtype="auto", device_map="cpu"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+    scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
+
+    pipe = Cosmos25PredictBase(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
     parser.add_argument(
-        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+        "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
     )
-    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
-    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
+    parser.add_argument("--text_encoder_path", type=str, default=None)
+    parser.add_argument("--tokenizer_path", type=str, default=None)
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -477,8 +582,6 @@ def get_args():
     if args.save_pipeline:
         assert args.transformer_ckpt_path is not None
         assert args.vae_type is not None
-        assert args.text_encoder_path is not None
-        assert args.tokenizer_path is not None
 
     if args.transformer_ckpt_path is not None:
         weights_only = "Cosmos-1.0" in args.transformer_type
@@ -490,17 +593,26 @@ def get_args():
     if args.vae_type is not None:
         if "Cosmos-1.0" in args.transformer_type:
             vae = convert_vae(args.vae_type)
-        else:
+        elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
             vae = AutoencoderKLWan.from_pretrained(
                 "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
             )
+        else:
+            raise AssertionError(f"{args.transformer_type} not supported")
+
         if not args.save_pipeline:
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
         if "Cosmos-1.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
             save_pipeline_cosmos_1_0(args, transformer, vae)
         elif "Cosmos-2.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
             save_pipeline_cosmos_2_0(args, transformer, vae)
+        elif "Cosmos-2.5" in args.transformer_type:
+            save_pipeline_cosmos_2_5(args, transformer, vae)
         else:
-            assert False
+            raise AssertionError(f"{args.transformer_type} not supported")
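
For reference, a minimal sketch of loading the directory produced by the Cosmos 2.5 conversion command in the module docstring above. It assumes `Cosmos25PredictBase` follows the standard `DiffusionPipeline` save/load layout; the call below is illustrative and not part of this commit.

```python
# Hypothetical usage sketch: load the pipeline saved via --save_pipeline.
# Assumes the converted/cosmos-p2.5-base-2b output path from the docstring above.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "converted/cosmos-p2.5-base-2b", torch_dtype=torch.bfloat16
)

# The loaded components should mirror what save_pipeline_cosmos_2_5 assembled:
# Qwen2.5-VL text encoder, Cosmos transformer, Wan 2.1 VAE, flow-match scheduler.
print(pipe.components.keys())
```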

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -346,6 +346,7 @@
             "FlowMatchEulerDiscreteScheduler",
             "FlowMatchHeunDiscreteScheduler",
             "FlowMatchLCMScheduler",
+            "FlowUniPCMultistepScheduler",
             "HeunDiscreteScheduler",
             "IPNDMScheduler",
             "KarrasVeScheduler",
@@ -1075,6 +1076,7 @@
         FlowMatchEulerDiscreteScheduler,
         FlowMatchHeunDiscreteScheduler,
         FlowMatchLCMScheduler,
+        FlowUniPCMultistepScheduler,
         HeunDiscreteScheduler,
         IPNDMScheduler,
         KarrasVeScheduler,
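
The commit message mentions the new `scheduling_flow_unipc_multistep.py`, and the export above makes `FlowUniPCMultistepScheduler` importable from the top-level namespace. A hedged sketch of swapping it onto a converted pipeline follows; the scheduler's constructor is not part of this diff, so the usual diffusers `from_config` pattern is assumed to apply.

```python
# Hypothetical usage sketch: swap the new scheduler onto a converted pipeline.
# FlowUniPCMultistepScheduler's own arguments are not shown in this commit,
# so it is built from the saved scheduler's config rather than constructed directly.
from diffusers import DiffusionPipeline, FlowUniPCMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("converted/cosmos-p2.5-base-2b")
pipe.scheduler = FlowUniPCMultistepScheduler.from_config(pipe.scheduler.config)
```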

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 14 additions & 0 deletions
@@ -439,6 +439,9 @@ def __init__(
         rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
         concat_padding_mask: bool = True,
         extra_pos_embed_type: Optional[str] = "learnable",
+        use_crossattn_projection: bool = False,
+        crossattn_proj_in_channels: int = 1024,
+        encoder_hidden_states_channels: int = 1024,
     ) -> None:
         super().__init__()
         hidden_size = num_attention_heads * attention_head_dim
@@ -485,6 +488,13 @@
             hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
         )
 
+        self.use_crossattn_projection = use_crossattn_projection
+        if self.use_crossattn_projection:
+            self.crossattn_proj = nn.Sequential(
+                nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
+                nn.GELU(),
+            )
+
         self.gradient_checkpointing = False
 
     def forward(
@@ -524,6 +534,7 @@ def forward(
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p_h
         post_patch_width = width // p_w
+
         hidden_states = self.patch_embed(hidden_states)
         hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]
 
@@ -546,6 +557,9 @@
         else:
             assert False
 
+        if self.use_crossattn_projection:
+            encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
+
         # 5. Transformer blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
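
For reference, a standalone sketch of the `crossattn_proj` block these changes add, using the widths from the `Cosmos-2.5-Predict-Base-2B` config (100352 in, 1024 out); the batch and token counts below are arbitrary assumptions for illustration.

```python
# Standalone sketch of the crossattn_proj block added in __init__ above.
import torch
from torch import nn

crossattn_proj = nn.Sequential(
    nn.Linear(100352, 1024, bias=True),  # crossattn_proj_in_channels -> encoder_hidden_states_channels
    nn.GELU(),
)

# Illustrative text-embedding tensor: [batch, tokens, features].
encoder_hidden_states = torch.randn(2, 8, 100352)
projected = crossattn_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([2, 8, 1024]) -- matches text_embed_dim=1024
```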
