|
17 | 17 | import os |
18 | 18 | import tempfile |
19 | 19 | import unittest |
20 | | -from types import SimpleNamespace |
21 | | -from typing import List |
22 | 20 |
|
23 | 21 | import numpy as np |
24 | 22 | import torch |
| 23 | +from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration |
25 | 24 |
|
26 | 25 | from diffusers import AutoencoderKLWan, Cosmos_2_5_PredictBase, CosmosTransformer3DModel, FlowUniPCMultistepScheduler |
27 | 26 |
|
|
34 | 33 | enable_full_determinism() |
35 | 34 |
|
36 | 35 |
|
37 | | -class DummyPredictTokenizer: |
38 | | - model_input_names = ["input_ids"] |
39 | | - |
40 | | - def __init__(self, vocab_size: int = 128): |
41 | | - self.vocab_size = vocab_size |
42 | | - |
43 | | - @classmethod |
44 | | - def from_pretrained(cls, *args, **kwargs): |
45 | | - return cls() |
46 | | - |
47 | | - def apply_chat_template( |
48 | | - self, |
49 | | - conversations: List[dict], |
50 | | - tokenize: bool = True, |
51 | | - add_generation_prompt: bool = False, |
52 | | - add_vision_id: bool = False, |
53 | | - max_length: int = 16, |
54 | | - truncation: bool = True, |
55 | | - padding: str = "max_length", |
56 | | - ): |
57 | | - return list(range(max_length)) |
58 | | - |
59 | | - def save_pretrained(self, save_directory: str): |
60 | | - os.makedirs(save_directory, exist_ok=True) |
61 | | - with open(os.path.join(save_directory, "tokenizer_config.json"), "w") as f: |
62 | | - json.dump({"vocab_size": self.vocab_size}, f) |
63 | | - |
64 | | - |
65 | | -class DummyPredictTextEncoder(torch.nn.Module): |
66 | | - config_name = "config.json" |
67 | | - |
68 | | - def __init__(self, vocab_size: int = 128, hidden_size: int = 16): |
69 | | - super().__init__() |
70 | | - self.emb = torch.nn.Embedding(vocab_size, hidden_size) |
71 | | - self.proj = torch.nn.Linear(hidden_size, hidden_size) |
72 | | - self.config = SimpleNamespace(hidden_size=hidden_size) |
73 | | - |
74 | | - @property |
75 | | - def dtype(self): |
76 | | - return next(self.parameters()).dtype |
77 | | - |
78 | | - @classmethod |
79 | | - def from_pretrained(cls, save_directory: str, **kwargs): |
80 | | - return cls() |
81 | | - |
82 | | - def save_pretrained(self, save_directory: str, safe_serialization: bool = False): |
83 | | - os.makedirs(save_directory, exist_ok=True) |
84 | | - torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin")) |
85 | | - with open(os.path.join(save_directory, self.config_name), "w") as f: |
86 | | - json.dump({"vocab_size": self.emb.num_embeddings, "hidden_size": self.emb.embedding_dim}, f) |
87 | | - |
88 | | - def forward(self, input_ids: torch.LongTensor, output_hidden_states: bool = False, **kwargs): |
89 | | - hidden = self.emb(input_ids) |
90 | | - hidden = self.proj(hidden) |
91 | | - hidden_states = ( |
92 | | - hidden, |
93 | | - hidden * 0.5, |
94 | | - hidden * 0.25, |
95 | | - ) |
96 | | - return SimpleNamespace(hidden_states=hidden_states) |
97 | | - |
98 | | - |
99 | 36 | class Cosmos_2_5_PredictBaseWrapper(Cosmos_2_5_PredictBase): |
100 | 37 | @staticmethod |
101 | 38 | def from_pretrained(*args, **kwargs): |
@@ -154,8 +91,11 @@ def get_dummy_components(self): |
154 | 91 | torch.manual_seed(0) |
155 | 92 | scheduler = FlowUniPCMultistepScheduler() |
156 | 93 |
|
157 | | - text_encoder = DummyPredictTextEncoder(hidden_size=16) |
158 | | - tokenizer = DummyPredictTokenizer() |
| 94 | + # NOTE: the tests use a tiny Qwen2-VL checkpoint as the text encoder instead (Reason1 is based on Qwen2.5-VL) |
| 95 | + text_encoder = Qwen2VLForConditionalGeneration.from_pretrained( |
| 96 | + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", |
| 97 | + ) |
| 98 | + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") |
159 | 99 |
|
160 | 100 | components = { |
161 | 101 | "transformer": transformer, |
|
0 commit comments