diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py index c662c1657513..dea370c65c58 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -145,7 +145,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: frames, height * width, query.dtype, query.device, batch_size=batch_size ) - x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask.unsqueeze(1)) # batch_size, 1, frames * height * width, channels