Commit c366b5a

update
1 parent 0fdd9d3 commit c366b5a

1 file changed

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 50 additions & 48 deletions
@@ -45,6 +45,55 @@
 enable_full_determinism()


+# TODO: This standalone function maintains backward compatibility with pipeline tests
+# (tests/pipelines/test_pipelines_common.py) and will be refactored.
+def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
+    """Create a dummy IP Adapter state dict for Flux transformer testing."""
+    ip_cross_attn_state_dict = {}
+    key_id = 0
+
+    for name in model.attn_processors.keys():
+        if name.startswith("single_transformer_blocks"):
+            continue
+
+        joint_attention_dim = model.config["joint_attention_dim"]
+        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+        sd = FluxIPAdapterAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+        ).state_dict()
+        ip_cross_attn_state_dict.update(
+            {
+                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+                f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+                f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+            }
+        )
+        key_id += 1
+
+    image_projection = ImageProjection(
+        cross_attention_dim=model.config["joint_attention_dim"],
+        image_embed_dim=(
+            model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
+        ),
+        num_image_text_embeds=4,
+    )
+
+    ip_image_projection_state_dict = {}
+    sd = image_projection.state_dict()
+    ip_image_projection_state_dict.update(
+        {
+            "proj.weight": sd["image_embeds.weight"],
+            "proj.bias": sd["image_embeds.bias"],
+            "norm.weight": sd["norm.weight"],
+            "norm.bias": sd["norm.bias"],
+        }
+    )
+
+    del sd
+    return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
+
+
 class FluxTransformerTesterConfig:
     model_class = FluxTransformer2DModel
     pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
@@ -169,54 +218,7 @@ def modify_inputs_for_ip_adapter(self, model, inputs_dict):
         return inputs_dict

     def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
-        from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
-
-        ip_cross_attn_state_dict = {}
-        key_id = 0
-
-        for name in model.attn_processors.keys():
-            if name.startswith("single_transformer_blocks"):
-                continue
-
-            joint_attention_dim = model.config["joint_attention_dim"]
-            hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
-            sd = FluxIPAdapterAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
-            ).state_dict()
-            ip_cross_attn_state_dict.update(
-                {
-                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
-                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
-                    f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
-                    f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
-                }
-            )
-
-            key_id += 1
-
-        image_projection = ImageProjection(
-            cross_attention_dim=model.config["joint_attention_dim"],
-            image_embed_dim=(
-                model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
-            ),
-            num_image_text_embeds=4,
-        )
-
-        ip_image_projection_state_dict = {}
-        sd = image_projection.state_dict()
-        ip_image_projection_state_dict.update(
-            {
-                "proj.weight": sd["image_embeds.weight"],
-                "proj.bias": sd["image_embeds.bias"],
-                "norm.weight": sd["norm.weight"],
-                "norm.bias": sd["norm.bias"],
-            }
-        )
-
-        del sd
-        ip_state_dict = {}
-        ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
-        return ip_state_dict
+        return create_flux_ip_adapter_state_dict(model)


 class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
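With this change, the mixin hook create_ip_adapter_state_dict simply forwards to the new module-level helper, so the model tests here and the pipeline tests mentioned in the TODO can share one implementation. A minimal usage sketch follows; it is an assumption rather than part of this commit, and presumes the tests package is importable from the repository root and that the tiny checkpoint named above, hf-internal-testing/tiny-flux-pipe, ships a transformer subfolder with at least one dual-stream block.

# Hypothetical usage sketch (not part of this commit).
from diffusers import FluxTransformer2DModel

# Assumes the tests package is importable from the repository root.
from tests.models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict

model = FluxTransformer2DModel.from_pretrained("hf-internal-testing/tiny-flux-pipe", subfolder="transformer")
state_dict = create_flux_ip_adapter_state_dict(model)

# The helper returns two sub-dicts: image-projection weights plus per-block
# cross-attention weights keyed by an integer block index.
assert set(state_dict) == {"image_proj", "ip_adapter"}
assert "proj.weight" in state_dict["image_proj"]
assert "0.to_k_ip.weight" in state_dict["ip_adapter"]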
