From abe7d2424c5a4ac68fdccbcb5b77416172f7a9e2 Mon Sep 17 00:00:00 2001
From: badayvedat <54285744+badayvedat@users.noreply.github.com>
Date: Wed, 17 Dec 2025 00:15:58 +0000
Subject: [PATCH 1/2] fix: flash_attn_3_func return value unpacking in
 _wrapped_flash_attn_3

---
 src/diffusers/models/attention_dispatch.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 310c44457c27..988c9995e9a7 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -657,6 +657,7 @@ def _wrapped_flash_attn_3(
         pack_gqa=pack_gqa,
         deterministic=deterministic,
         sm_margin=sm_margin,
+        return_attn_probs=True,
     )
     lse = lse.permute(0, 2, 1)
     return out, lse

From 3de380d3266586a75c5176d7a72a52ad3745ef0d Mon Sep 17 00:00:00 2001
From: badayvedat <54285744+badayvedat@users.noreply.github.com>
Date: Wed, 17 Dec 2025 23:58:02 +0000
Subject: [PATCH 2/2] refactor: remove old wrapper

---
 src/diffusers/models/attention_dispatch.py | 82 ++--------------------
 1 file changed, 7 insertions(+), 75 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 988c9995e9a7..82f7e1f8294e 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -616,79 +616,6 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
         raise
 
 
-# ===== torch op registrations =====
-# Registrations are required for fullgraph tracing compatibility
-# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
-# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    qv: Optional[torch.Tensor] = None,
-    q_descale: Optional[torch.Tensor] = None,
-    k_descale: Optional[torch.Tensor] = None,
-    v_descale: Optional[torch.Tensor] = None,
-    attention_chunk: int = 0,
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: Optional[bool] = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Hardcoded for now because pytorch does not support tuple/int type hints
-    window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
-        q=q,
-        k=k,
-        v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        qv=qv,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-        window_size=window_size,
-        attention_chunk=attention_chunk,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
-        return_attn_probs=True,
-    )
-    lse = lse.permute(0, 2, 1)
-    return out, lse
-
-
-@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
-def _(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    qv: Optional[torch.Tensor] = None,
-    q_descale: Optional[torch.Tensor] = None,
-    k_descale: Optional[torch.Tensor] = None,
-    v_descale: Optional[torch.Tensor] = None,
-    attention_chunk: int = 0,
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: Optional[bool] = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    window_size = (-1, -1)  # noqa: F841
-    # A lot of the parameters here are not yet used in any way within diffusers.
-    # We can safely ignore for now and keep the fake op shape propagation simple.
-    batch_size, seq_len, num_heads, head_dim = q.shape
-    lse_shape = (batch_size, seq_len, num_heads)
-    return torch.empty_like(q), q.new_empty(lse_shape)
-
-
 # ===== Helper functions to use attention backends with templated CP autograd functions =====
 
 
@@ -1618,14 +1545,19 @@ def _flash_attention_3(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out, lse = _wrapped_flash_attn_3(
+    out = flash_attn_3_func(
         q=query,
         k=key,
         v=value,
         softmax_scale=scale,
         causal=is_causal,
+        return_attn_probs=return_lse,
     )
-    return (out, lse) if return_lse else out
+    if return_lse:
+        # out is (output, lse) tuple when return_attn_probs=True
+        # lse needs to be permuted from (batch, heads, seq) to (batch, seq, heads)
+        return out[0], out[1].permute(0, 2, 1)
+    return out
 
 
 @_AttentionBackendRegistry.register(
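
For reference, outside the patch itself: a minimal sketch of the call pattern the second commit inlines into _flash_attention_3. The helper name fa3_with_lse and the flash_attn_interface import path are illustrative assumptions, not part of the diff; only the flash_attn_3_func keyword arguments, the tuple indexing, and the LSE permute mirror what the patch adds.

    import torch

    # The FA3 beta exposes its Python API as `flash_attn_interface`; the exact
    # import path may differ depending on how FlashAttention-3 was installed.
    from flash_attn_interface import flash_attn_func as flash_attn_3_func

    def fa3_with_lse(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        scale: float | None = None,
        is_causal: bool = False,
        return_lse: bool = False,
    ):
        # With return_attn_probs=False the call returns just the output tensor;
        # with True it returns a tuple whose first two elements are (output, lse).
        out = flash_attn_3_func(
            q=query,
            k=key,
            v=value,
            softmax_scale=scale,
            causal=is_causal,
            return_attn_probs=return_lse,
        )
        if return_lse:
            # lse comes back as (batch, heads, seq_len); callers in diffusers
            # expect (batch, seq_len, heads), hence the permute.
            return out[0], out[1].permute(0, 2, 1)
        return out

A quick sanity check on Hopper hardware would be to call fa3_with_lse on bfloat16 tensors shaped (batch, seq_len, heads, head_dim) with return_lse=True and confirm the second return value has shape (batch, seq_len, heads).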