1+ """
2+ # Cosmos 2 Predict
3+
4+ Download checkpoint
5+ ```bash
6+ hf download nvidia/Cosmos-Predict2-2B-Text2Image
7+ ```
8+
9+ convert checkpoint
10+ ```bash
11+ transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
12+
13+ python scripts/convert_cosmos_to_diffusers.py \
14+ --transformer_ckpt_path $transformer_ckpt_path \
15+ --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
16+ --text_encoder_path google-t5/t5-11b \
17+ --tokenizer_path google-t5/t5-11b \
18+ --vae_type wan2.1 \
19+ --output_path converted/cosmos-p2-t2i-2b \
20+ --save_pipeline
21+ ```
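+
+To sanity-check the result, the converted folder can be loaded back with diffusers. A
+minimal sketch (it assumes the conversion wrote a complete pipeline folder, that
+`Cosmos2TextToImagePipeline` is the matching pipeline class for this checkpoint, and that
+any safety-checker dependencies are satisfied locally):
+```python
+import torch
+
+from diffusers import Cosmos2TextToImagePipeline
+
+# Load the locally converted pipeline and run a single prompt.
+pipe = Cosmos2TextToImagePipeline.from_pretrained("converted/cosmos-p2-t2i-2b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+image = pipe(prompt="A red sports car on a mountain road at sunset").images[0]
+image.save("cosmos_t2i_sample.png")
+```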
+
+# Cosmos 2.5 Predict
+
+Download the checkpoint:
+```bash
+hf download nvidia/Cosmos-Predict2.5-2B
+```
+
+Convert the checkpoint:
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+  --transformer_type Cosmos-2.5-Predict-Base-2B \
+  --transformer_ckpt_path $transformer_ckpt_path \
+  --vae_type wan2.1 \
+  --output_path converted/cosmos-p2.5-base-2b \
+  --save_pipeline
+```
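+
+The converted folder can then be loaded back for a quick smoke test. A minimal sketch
+(it assumes `Cosmos25PredictBase`, the pipeline class used by this script, exposes the
+standard `from_pretrained` interface and that the serialized safety-checker stub loads
+without extra arguments; the inference call signature is not shown here):
+```python
+import torch
+
+from diffusers.pipelines.cosmos.pipeline_cosmos25_predict import Cosmos25PredictBase
+
+# Load the converted pipeline and list its components; expect text_encoder,
+# tokenizer, transformer, vae, and scheduler to deserialize without errors.
+pipe = Cosmos25PredictBase.from_pretrained("converted/cosmos-p2.5-base-2b", torch_dtype=torch.bfloat16)
+print(pipe.components.keys())
+```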
+
+"""
+
 import argparse
 import pathlib
+import sys
 from typing import Any, Dict

 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import snapshot_download
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast

 from diffusers import (
     AutoencoderKLCosmos,
     EDMEulerScheduler,
     FlowMatchEulerDiscreteScheduler,
 )
+from diffusers.pipelines.cosmos.pipeline_cosmos25_predict import Cosmos25PredictBase


 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +278,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "concat_padding_mask": True,
         "extra_pos_embed_type": None,
     },
+    "Cosmos-2.5-Predict-Base-2B": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        # NOTE: the source config sets pos_emb_learnable: 'True', but the corresponding params are missing
+        "extra_pos_embed_type": None,
+        "use_crossattn_projection": True,
+        "crossattn_proj_in_channels": 100352,
+        "encoder_hidden_states_channels": 1024,
+    },
 }

 VAE_KEYS_RENAME_DICT = {
@@ -334,6 +398,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
     elif "Cosmos-2.0" in transformer_type:
         TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
         TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    elif "Cosmos-2.5" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
     else:
         assert False

@@ -347,6 +414,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
             new_key = new_key.removeprefix(PREFIX_KEY)
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
+        print(key, "->", new_key, flush=True)
         update_state_dict_(original_state_dict, key, new_key)

     for key in list(original_state_dict.keys()):
@@ -355,6 +423,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
                 continue
             handler_fn_inplace(key, original_state_dict)

+    expected_keys = set(transformer.state_dict().keys())
+    mapped_keys = set(original_state_dict.keys())
+    missing_keys = expected_keys - mapped_keys
+    unexpected_keys = mapped_keys - expected_keys
+    if missing_keys:
+        print(f"ERROR: missing keys ({len(missing_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in missing_keys:
+            print(k)
+        sys.exit(1)
+    if unexpected_keys:
+        print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in unexpected_keys:
+            print(k)
+        sys.exit(2)
+
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
     return transformer

@@ -444,17 +527,39 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
     pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")


+def save_pipeline_cosmos_2_5(args, transformer, vae):
+    text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
+    tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
+
+    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        text_encoder_path, torch_dtype="auto", device_map="cpu"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+    scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
+
+    pipe = Cosmos25PredictBase(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
     parser.add_argument(
-        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+        "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
     )
-    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
-    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
+    parser.add_argument("--text_encoder_path", type=str, default=None)
+    parser.add_argument("--tokenizer_path", type=str, default=None)
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -477,8 +582,6 @@ def get_args():
     if args.save_pipeline:
         assert args.transformer_ckpt_path is not None
         assert args.vae_type is not None
-        assert args.text_encoder_path is not None
-        assert args.tokenizer_path is not None

     if args.transformer_ckpt_path is not None:
         weights_only = "Cosmos-1.0" in args.transformer_type
@@ -490,17 +593,26 @@ def get_args():
     if args.vae_type is not None:
         if "Cosmos-1.0" in args.transformer_type:
             vae = convert_vae(args.vae_type)
-        else:
+        elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
             vae = AutoencoderKLWan.from_pretrained(
                 "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
             )
+        else:
+            raise AssertionError(f"{args.transformer_type} not supported")
+
         if not args.save_pipeline:
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

     if args.save_pipeline:
         if "Cosmos-1.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
             save_pipeline_cosmos_1_0(args, transformer, vae)
         elif "Cosmos-2.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
             save_pipeline_cosmos_2_0(args, transformer, vae)
+        elif "Cosmos-2.5" in args.transformer_type:
+            save_pipeline_cosmos_2_5(args, transformer, vae)
         else:
-            assert False
+            raise AssertionError(f"{args.transformer_type} not supported")