221 changes: 221 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -2785,6 +2785,227 @@ def __call__(
return hidden_states


class MemoryEfficientSelectiveAttnProcessor2_0:
r"""
Processor for implementing selective attention (https://arxiv.org/abs/2410.02703) with optional token pruning.

Args:
selection_head_idx (`int`, defaults to 0):
Index of the attention head to use for computing selection scores.
bidirectional (`bool`, defaults to `True`):
Whether to use bidirectional masking (for diffusion) or causal masking (for LLMs).
enable_pruning (`bool`, defaults to `False`):
Whether to remove heavily masked tokens from computation.
pruning_threshold (`float`, defaults to 1.0):
Tokens with masking scores above this threshold are pruned.
masking_strength (`float`, defaults to 0.1):
Scaling factor for the masking values. Lower values result in gentler masking.
Set to 0.0 to disable selective masking entirely (useful for debugging).
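Example:

A minimal sketch of wiring the processor into a model directly (the checkpoint id is
illustrative; see also `enable_selective_attention` on the model for a one-call helper):

```py
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import MemoryEfficientSelectiveAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# Self-attention layers apply the selective mask; cross-attention falls back to plain SDPA.
unet.set_attn_processor(MemoryEfficientSelectiveAttnProcessor2_0(masking_strength=0.1))
```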
"""

def __init__(
self,
selection_head_idx: int = 0,
bidirectional: bool = True,
enable_pruning: bool = False,
pruning_threshold: float = 1.0,
masking_strength: float = 0.1,
):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"MemoryEfficientSelectiveAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.selection_head_idx = selection_head_idx
self.bidirectional = bidirectional
self.enable_pruning = enable_pruning
self.pruning_threshold = pruning_threshold
self.masking_strength = masking_strength

def _compute_selection_scores(
self,
query: torch.Tensor,
key: torch.Tensor,
selection_head_idx: int,
) -> torch.Tensor:
"""Compute selection scores S from the specified attention head."""
seq_len = query.shape[2]
head_dim = query.shape[3]

q_select = query[:, selection_head_idx, :, :]
k_select = key[:, selection_head_idx, :, :]
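# q_select / k_select: (batch, seq_len, head_dim); selection_logits below are (batch, seq_len, seq_len)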

scale = head_dim ** -0.5
selection_logits = torch.matmul(q_select, k_select.transpose(-2, -1)) * scale

# Apply constraints: ReLU, zero first column, zero diagonal
S = torch.relu(selection_logits)
S[:, :, 0] = 0
diagonal_mask = torch.eye(seq_len, device=S.device, dtype=torch.bool)
S = S.masked_fill(diagonal_mask.unsqueeze(0), 0)

return S

def _compute_masking_matrix(
self,
S: torch.Tensor,
bidirectional: bool = True,
) -> torch.Tensor:
"""Compute the accumulated masking matrix F from selection scores S."""
if bidirectional:
# Bidirectional: F[i,j] = sum_{k!=i} S[k,j] (every other token's vote to mask token j)
total_votes = S.sum(dim=-2, keepdim=True)
F = total_votes - S
else:
# Causal: F[i,j] = sum_{k<i} S[k,j]
cumsum_S = torch.cumsum(S, dim=-2)
F = cumsum_S - S

return F

def __call__(
self,
attn: "Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = (
"The `scale` argument is deprecated and will be ignored. Please remove it, "
"as passing it will raise an error in the future. `scale` should directly be "
"passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
)
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

is_self_attention = encoder_hidden_states is None
if is_self_attention:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Selective attention: only apply to self-attention with at least 2 tokens
if is_self_attention and query.shape[2] > 1:
if self.selection_head_idx >= attn.heads:
logger.warning(
f"selection_head_idx ({self.selection_head_idx}) >= num_heads ({attn.heads}). "
f"Falling back to head 0."
)
selection_head_idx = 0
else:
selection_head_idx = self.selection_head_idx

# Compute selection scores and masking matrix
S = self._compute_selection_scores(query, key, selection_head_idx)
F_mask = self._compute_masking_matrix(S, self.bidirectional)
F_mask = F_mask.unsqueeze(1).expand(-1, attn.heads, -1, -1)

# Normalize F_mask to prevent explosion for long sequences
# Scale by 1/seq_len to make masking strength sequence-length independent
seq_len = query.shape[2]
F_mask = F_mask / seq_len

# Apply masking strength scaling
# Lower strength = gentler masking, preserving more of the original attention
F_mask = F_mask * self.masking_strength
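# Net effect inside scaled_dot_product_attention: logits[i, j] -= masking_strength * F[i, j] / seq_len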

# Combine with existing attention mask (F is subtracted from logits)
selective_mask = -F_mask
if attention_mask is not None:
attention_mask = attention_mask + selective_mask
else:
attention_mask = selective_mask

# Optional pruning: remove heavily masked tokens
if self.enable_pruning:
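# F_mask is identical across heads (it was expanded from a single selection head), so head 0
# is representative. For each key token, take the strongest masking it receives from any
# query position as its pruning score.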
max_masking_per_token = F_mask[:, 0, :, :].max(dim=-2)[0]
keep_mask = max_masking_per_token < self.pruning_threshold

# Ensure at least 10% of tokens are kept
min_tokens = max(1, int(sequence_length * 0.1))
tokens_to_keep = keep_mask.sum(dim=-1)

for b in range(batch_size):
if tokens_to_keep[b] < min_tokens:
_, indices = max_masking_per_token[b].topk(min_tokens, largest=False)
keep_mask[b] = False
keep_mask[b, indices] = True

# Apply pruning mask
if not keep_mask.all():
prune_mask = ~keep_mask
prune_mask = prune_mask.unsqueeze(1).unsqueeze(2)
prune_mask = prune_mask.expand(-1, attn.heads, query.shape[2], -1)

if attention_mask is not None:
attention_mask = attention_mask.masked_fill(prune_mask, float("-inf"))
else:
attention_mask = torch.zeros_like(F_mask)
attention_mask = attention_mask.masked_fill(prune_mask, float("-inf"))

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# Linear projection
hidden_states = attn.to_out[0](hidden_states)
# Dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class XLAFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
67 changes: 67 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -435,6 +435,73 @@ def disable_xformers_memory_efficient_attention(self) -> None:
"""
self.set_use_memory_efficient_attention_xformers(False)

def enable_selective_attention(
self,
selection_head_idx: int = 0,
bidirectional: bool = True,
enable_pruning: bool = False,
pruning_threshold: float = 1.0,
masking_strength: float = 0.1,
) -> None:
r"""
Enable selective attention for memory-efficient inference.

Based on "Selective Attention Improves Transformer" (https://arxiv.org/abs/2410.02703).

Parameters:
selection_head_idx (`int`, defaults to 0):
Index of the attention head to use for computing selection scores.
bidirectional (`bool`, defaults to `True`):
Whether to use bidirectional masking (for diffusion) or causal masking (for LLMs).
enable_pruning (`bool`, defaults to `False`):
Whether to remove heavily masked tokens from computation.
pruning_threshold (`float`, defaults to 1.0):
Tokens with masking scores above this threshold are pruned.
masking_strength (`float`, defaults to 0.1):
Scaling factor for the masking values. Lower values result in gentler masking. Start with 0.1 and increase as needed.
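Example:

A minimal sketch (the checkpoint id is illustrative; any diffusers model whose attention
modules currently use the default `AttnProcessor2_0` gets switched over):

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.enable_selective_attention(masking_strength=0.1)
# ... run inference ...
unet.disable_selective_attention()  # revert to the standard AttnProcessor2_0
```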
"""
from .attention_processor import AttnProcessor2_0, MemoryEfficientSelectiveAttnProcessor2_0

def fn_recursive_set_selective_attention(module: torch.nn.Module):
if hasattr(module, "set_processor"):
current_processor = module.get_processor() if hasattr(module, "get_processor") else None
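# Only replace the default SDPA processor (or an unset one); custom processors such as
# LoRA or IP-Adapter variants are left untouched.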
if current_processor is None or isinstance(current_processor, AttnProcessor2_0):
module.set_processor(
MemoryEfficientSelectiveAttnProcessor2_0(
selection_head_idx=selection_head_idx,
bidirectional=bidirectional,
enable_pruning=enable_pruning,
pruning_threshold=pruning_threshold,
masking_strength=masking_strength,
)
)

for child in module.children():
fn_recursive_set_selective_attention(child)

for module in self.children():
if isinstance(module, torch.nn.Module):
fn_recursive_set_selective_attention(module)

def disable_selective_attention(self) -> None:
r"""
Disable selective attention and revert affected attention modules to the standard `AttnProcessor2_0`.
"""
from .attention_processor import AttnProcessor2_0, MemoryEfficientSelectiveAttnProcessor2_0

def fn_recursive_disable_selective_attention(module: torch.nn.Module):
if hasattr(module, "set_processor"):
current_processor = module.get_processor() if hasattr(module, "get_processor") else None
if isinstance(current_processor, MemoryEfficientSelectiveAttnProcessor2_0):
module.set_processor(AttnProcessor2_0())

for child in module.children():
fn_recursive_disable_selective_attention(child)

for module in self.children():
if isinstance(module, torch.nn.Module):
fn_recursive_disable_selective_attention(module)

def enable_layerwise_casting(
self,
storage_dtype: torch.dtype = torch.float8_e4m3fn,