fix: flash_attn_3_func value unpacking in _wrapped_flash_attn_3 w th compile #12851
+1
−0
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
The
_wrapped_flash_attn_3function unconditionally unpacks bothoutandlsefrom the return value:However, it was not passing
return_attn_probs=Trueto request the tuple return. Since Dao-AILab/flash-attention@203b9b3,flash_attn_funcreturns onlyoutby default, causing:How does this pr fixes it
Adds
return_attn_probs=Trueto theflash_attn_3_funccall, consistent with how_flash_attention_3_hubhandles.Reproduction
and bring your own flash attention build, to repro this i built it from source @ Dao-AILab/flash-attention@ac9b5f1
Alternative
The wrapper seems to exist to support fa3 as custom op. However, fa3 now has native torch.compile support as of Dao-AILab/flash-attention@c7697bb. This might be making
_wrapped_flash_attn_3redundant, tho i dont really know if that is the only reason.Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?