45 | 45 | enable_full_determinism()
46 | 46 |
47 | 47 |
| 48 | +# TODO: This standalone function maintains backward compatibility with pipeline tests
| 49 | +# (tests/pipelines/test_pipelines_common.py) and will be refactored.
| 50 | +def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
| 51 | +    """Create a dummy IP Adapter state dict for Flux transformer testing."""
| 52 | +    ip_cross_attn_state_dict = {}
| 53 | +    key_id = 0
| 54 | +
| 55 | +    for name in model.attn_processors.keys():
| 56 | +        if name.startswith("single_transformer_blocks"):
| 57 | +            continue
| 58 | +
| 59 | +        joint_attention_dim = model.config["joint_attention_dim"]
| 60 | +        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
| 61 | +        sd = FluxIPAdapterAttnProcessor(
| 62 | +            hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
| 63 | +        ).state_dict()
| 64 | +        ip_cross_attn_state_dict.update(
| 65 | +            {
| 66 | +                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
| 67 | +                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
| 68 | +                f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
| 69 | +                f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
| 70 | +            }
| 71 | +        )
| 72 | +        key_id += 1
| 73 | +
| 74 | +    image_projection = ImageProjection(
| 75 | +        cross_attention_dim=model.config["joint_attention_dim"],
| 76 | +        image_embed_dim=(
| 77 | +            model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
| 78 | +        ),
| 79 | +        num_image_text_embeds=4,
| 80 | +    )
| 81 | +
| 82 | +    ip_image_projection_state_dict = {}
| 83 | +    sd = image_projection.state_dict()
| 84 | +    ip_image_projection_state_dict.update(
| 85 | +        {
| 86 | +            "proj.weight": sd["image_embeds.weight"],
| 87 | +            "proj.bias": sd["image_embeds.bias"],
| 88 | +            "norm.weight": sd["norm.weight"],
| 89 | +            "norm.bias": sd["norm.bias"],
| 90 | +        }
| 91 | +    )
| 92 | +
| 93 | +    del sd
| 94 | +    return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
| 95 | +
| 96 | +
48 | 97 | class FluxTransformerTesterConfig:
49 | 98 |     model_class = FluxTransformer2DModel
50 | 99 |     pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
@@ -169,54 +218,7 @@ def modify_inputs_for_ip_adapter(self, model, inputs_dict): |
169 | 218 |         return inputs_dict
170 | 219 |
171 | 220 |     def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
172 | | -        from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
173 | | -
174 | | -        ip_cross_attn_state_dict = {}
175 | | -        key_id = 0
176 | | -
177 | | -        for name in model.attn_processors.keys():
178 | | -            if name.startswith("single_transformer_blocks"):
179 | | -                continue
180 | | -
181 | | -            joint_attention_dim = model.config["joint_attention_dim"]
182 | | -            hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
183 | | -            sd = FluxIPAdapterAttnProcessor(
184 | | -                hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
185 | | -            ).state_dict()
186 | | -            ip_cross_attn_state_dict.update(
187 | | -                {
188 | | -                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
189 | | -                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
190 | | -                    f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
191 | | -                    f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
192 | | -                }
193 | | -            )
194 | | -
195 | | -            key_id += 1
196 | | -
197 | | -        image_projection = ImageProjection(
198 | | -            cross_attention_dim=model.config["joint_attention_dim"],
199 | | -            image_embed_dim=(
200 | | -                model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
201 | | -            ),
202 | | -            num_image_text_embeds=4,
203 | | -        )
204 | | -
205 | | -        ip_image_projection_state_dict = {}
206 | | -        sd = image_projection.state_dict()
207 | | -        ip_image_projection_state_dict.update(
208 | | -            {
209 | | -                "proj.weight": sd["image_embeds.weight"],
210 | | -                "proj.bias": sd["image_embeds.bias"],
211 | | -                "norm.weight": sd["norm.weight"],
212 | | -                "norm.bias": sd["norm.bias"],
213 | | -            }
214 | | -        )
215 | | -
216 | | -        del sd
217 | | -        ip_state_dict = {}
218 | | -        ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
219 | | -        return ip_state_dict
| 221 | +        return create_flux_ip_adapter_state_dict(model)
220 | 222 |
221 | 223 |
222 | 224 | class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
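For orientation, here is a minimal usage sketch of the new helper. It is not part of the commit: it assumes `create_flux_ip_adapter_state_dict` from the diff above is in scope, and that the `hf-internal-testing/tiny-flux-pipe` checkpoint named in the tester config exposes a `transformer` subfolder.

```python
# Sketch only, under the assumptions stated above: build the dummy IP Adapter
# state dict for a tiny Flux transformer and inspect its layout.
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer"
)
state_dict = create_flux_ip_adapter_state_dict(model)

# Two top-level groups, exactly as returned by the helper.
print(sorted(state_dict))                # ['image_proj', 'ip_adapter']
print(sorted(state_dict["image_proj"]))  # ['norm.bias', 'norm.weight', 'proj.bias', 'proj.weight']
```

The returned layout matches what the removed `create_ip_adapter_state_dict` method produced, so the pipeline tests referenced in the TODO continue to receive the same structure.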