From d8a0025de9c8ac791f4ccd5e37bea9b63b8240c8 Mon Sep 17 00:00:00 2001
From: cdutr
Date: Mon, 15 Dec 2025 15:47:47 -0300
Subject: [PATCH] Adds memory-efficient selective attention processor

Implements a new attention processing technique based on the "Selective
Attention Improves Transformer" paper, enabling more efficient and flexible
attention mechanisms.

Key features:
- Introduces MemoryEfficientSelectiveAttnProcessor2_0 for advanced attention masking
- Supports configurable token selection and optional pruning
- Adds methods to enable/disable selective attention across model modules
- Provides fine-grained control over masking strength and token selection

Enables more intelligent and computationally efficient attention by allowing
selective token interaction and pruning, which can improve model performance
and reduce computational overhead.
---
 src/diffusers/models/attention_processor.py | 221 ++++++++++++++++++++
 src/diffusers/models/modeling_utils.py      |  67 ++++++
 2 files changed, 288 insertions(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 66455d733aee..91c31fe6dcf1 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2785,6 +2785,227 @@ def __call__(
         return hidden_states


+class MemoryEfficientSelectiveAttnProcessor2_0:
+    r"""
+    Processor for implementing selective attention (https://arxiv.org/abs/2410.02703) with optional token pruning.
+
+    Args:
+        selection_head_idx (`int`, defaults to 0):
+            Index of the attention head to use for computing selection scores.
+        bidirectional (`bool`, defaults to `True`):
+            Whether to use bidirectional masking (for diffusion) or causal masking (for LLMs).
+        enable_pruning (`bool`, defaults to `False`):
+            Whether to remove heavily masked tokens from computation.
+        pruning_threshold (`float`, defaults to 1.0):
+            Tokens with masking scores above this threshold are pruned.
+        masking_strength (`float`, defaults to 0.1):
+            Scaling factor for the masking values. Lower values = gentler masking.
+            Set to 0.0 to disable selective masking entirely (useful for debugging).
+    """
+
+    def __init__(
+        self,
+        selection_head_idx: int = 0,
+        bidirectional: bool = True,
+        enable_pruning: bool = False,
+        pruning_threshold: float = 1.0,
+        masking_strength: float = 0.1,
+    ):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "MemoryEfficientSelectiveAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        self.selection_head_idx = selection_head_idx
+        self.bidirectional = bidirectional
+        self.enable_pruning = enable_pruning
+        self.pruning_threshold = pruning_threshold
+        self.masking_strength = masking_strength
+
+    def _compute_selection_scores(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        selection_head_idx: int,
+    ) -> torch.Tensor:
+        """Compute selection scores S from the specified attention head."""
+        seq_len = query.shape[2]
+        head_dim = query.shape[3]
+
+        q_select = query[:, selection_head_idx, :, :]
+        k_select = key[:, selection_head_idx, :, :]
+
+        scale = head_dim ** -0.5
+        selection_logits = torch.matmul(q_select, k_select.transpose(-2, -1)) * scale
+
+        # Apply constraints: ReLU, zero first column, zero diagonal
+        S = torch.relu(selection_logits)
+        S[:, :, 0] = 0
+        diagonal_mask = torch.eye(seq_len, device=S.device, dtype=torch.bool)
+        S = S.masked_fill(diagonal_mask.unsqueeze(0), 0)
+
+        return S
+
+    def _compute_masking_matrix(
+        self,
+        S: torch.Tensor,
+        bidirectional: bool = True,
+    ) -> torch.Tensor:
+        """Compute the accumulated masking matrix F from selection scores S."""
+        if bidirectional:
+            # Symmetric: F[i,j] = sum_{k!=i} S[k,j]
+            total_votes = S.sum(dim=-2, keepdim=True)
+            F = total_votes - S
+        else:
+            # Causal: F[i,j] = sum_{k<i} S[k,j]
+            F = torch.cumsum(S, dim=-2) - S
+
+        return F
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = (
+                "The `scale` argument is deprecated and will be ignored. Please remove it, "
+                "as passing it will raise an error in the future. `scale` should directly be "
+                "passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            )
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        is_self_attention = encoder_hidden_states is None
+        if is_self_attention:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Selective attention: only apply to self-attention with at least 2 tokens
+        if is_self_attention and query.shape[2] > 1:
+            if self.selection_head_idx >= attn.heads:
+                logger.warning(
+                    f"selection_head_idx ({self.selection_head_idx}) >= num_heads ({attn.heads}). "
+                    f"Falling back to head 0."
+                )
+                selection_head_idx = 0
+            else:
+                selection_head_idx = self.selection_head_idx
+
+            # Compute selection scores and masking matrix
+            S = self._compute_selection_scores(query, key, selection_head_idx)
+            F_mask = self._compute_masking_matrix(S, self.bidirectional)
+            F_mask = F_mask.unsqueeze(1).expand(-1, attn.heads, -1, -1)
+
+            # Normalize F_mask to prevent explosion for long sequences
+            # Scale by 1/seq_len to make masking strength sequence-length independent
+            seq_len = query.shape[2]
+            F_mask = F_mask / seq_len
+
+            # Apply masking strength scaling
+            # Lower strength = gentler masking, preserving more of the original attention
+            F_mask = F_mask * self.masking_strength
+
+            # Combine with existing attention mask (F is subtracted from logits)
+            selective_mask = -F_mask
+            if attention_mask is not None:
+                attention_mask = attention_mask + selective_mask
+            else:
+                attention_mask = selective_mask
+
+            # Optional pruning: remove heavily masked tokens
+            if self.enable_pruning:
+                max_masking_per_token = F_mask[:, 0, :, :].max(dim=-2)[0]
+                keep_mask = max_masking_per_token < self.pruning_threshold
+
+                # Ensure at least 10% of tokens are kept
+                min_tokens = max(1, int(sequence_length * 0.1))
+                tokens_to_keep = keep_mask.sum(dim=-1)
+
+                for b in range(batch_size):
+                    if tokens_to_keep[b] < min_tokens:
+                        _, indices = max_masking_per_token[b].topk(min_tokens, largest=False)
+                        keep_mask[b] = False
+                        keep_mask[b, indices] = True
+
+                # Apply pruning mask
+                if not keep_mask.all():
+                    prune_mask = ~keep_mask
+                    prune_mask = prune_mask.unsqueeze(1).unsqueeze(2)
+                    prune_mask = prune_mask.expand(-1, attn.heads, query.shape[2], -1)
+
+                    if attention_mask is not None:
+                        attention_mask = attention_mask.masked_fill(prune_mask, float("-inf"))
+                    else:
+                        attention_mask = torch.zeros_like(F_mask)
+                        attention_mask = attention_mask.masked_fill(prune_mask, float("-inf"))
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Linear projection
+        hidden_states = attn.to_out[0](hidden_states)
+        # Dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class XLAFlashAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 41da95d3a2a2..8349a3e1bd05 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -435,6 +435,73 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         """
         self.set_use_memory_efficient_attention_xformers(False)

+    def enable_selective_attention(
+        self,
+        selection_head_idx: int = 0,
+        bidirectional: bool = True,
+        enable_pruning: bool = False,
+        pruning_threshold: float = 1.0,
+        masking_strength: float = 0.1,
+    ) -> None:
+        r"""
+        Enable selective attention for memory-efficient inference.
+
+        Based on "Selective Attention Improves Transformer" (https://arxiv.org/abs/2410.02703).
+
+        Parameters:
+            selection_head_idx (`int`, defaults to 0):
+                Index of the attention head to use for computing selection scores.
+            bidirectional (`bool`, defaults to `True`):
+                Whether to use bidirectional masking (for diffusion) or causal masking (for LLMs).
+            enable_pruning (`bool`, defaults to `False`):
+                Whether to remove heavily masked tokens from computation.
+            pruning_threshold (`float`, defaults to 1.0):
+                Tokens with masking scores above this threshold are pruned.
+            masking_strength (`float`, defaults to 0.1):
+                Scaling factor for masking values. Lower = gentler masking. Start with 0.1 and increase.
+        """
+        from .attention_processor import AttnProcessor2_0, MemoryEfficientSelectiveAttnProcessor2_0
+
+        def fn_recursive_set_selective_attention(module: torch.nn.Module):
+            if hasattr(module, "set_processor"):
+                current_processor = module.get_processor() if hasattr(module, "get_processor") else None
+                if current_processor is None or isinstance(current_processor, AttnProcessor2_0):
+                    module.set_processor(
+                        MemoryEfficientSelectiveAttnProcessor2_0(
+                            selection_head_idx=selection_head_idx,
+                            bidirectional=bidirectional,
+                            enable_pruning=enable_pruning,
+                            pruning_threshold=pruning_threshold,
+                            masking_strength=masking_strength,
+                        )
+                    )
+
+            for child in module.children():
+                fn_recursive_set_selective_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_selective_attention(module)
+
+    def disable_selective_attention(self) -> None:
+        r"""
+        Disable selective attention and revert to standard attention.
+        """
+        from .attention_processor import AttnProcessor2_0, MemoryEfficientSelectiveAttnProcessor2_0
+
+        def fn_recursive_disable_selective_attention(module: torch.nn.Module):
+            if hasattr(module, "set_processor"):
+                current_processor = module.get_processor() if hasattr(module, "get_processor") else None
+                if isinstance(current_processor, MemoryEfficientSelectiveAttnProcessor2_0):
+                    module.set_processor(AttnProcessor2_0())
+
+            for child in module.children():
+                fn_recursive_disable_selective_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_disable_selective_attention(module)
+
     def enable_layerwise_casting(
         self,
         storage_dtype: torch.dtype = torch.float8_e4m3fn,
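
Reviewer note, not part of the patch: the sketch below re-derives the selection/masking math used by _compute_selection_scores and _compute_masking_matrix on a toy tensor, to make the bidirectional vs. causal accumulation concrete. All shapes and values here are invented for illustration; the processor itself takes q_select/k_select from one head of the module's query/key projections.

    # Standalone sketch of the selective-attention masking math (illustration only).
    import torch

    torch.manual_seed(0)
    batch, seq_len, head_dim = 1, 4, 8

    # Stand-ins for query[:, selection_head_idx] and key[:, selection_head_idx].
    q_select = torch.randn(batch, seq_len, head_dim)
    k_select = torch.randn(batch, seq_len, head_dim)

    # Selection scores S: scaled dot products, ReLU, first column and diagonal zeroed,
    # mirroring _compute_selection_scores.
    S = torch.relu(torch.matmul(q_select, k_select.transpose(-2, -1)) * head_dim ** -0.5)
    S[:, :, 0] = 0
    S = S.masked_fill(torch.eye(seq_len, dtype=torch.bool), 0)

    # Bidirectional accumulation (diffusion): F[i, j] = sum over k != i of S[k, j].
    F_bidirectional = S.sum(dim=-2, keepdim=True) - S

    # Causal accumulation (autoregressive): F[i, j] = sum over k < i of S[k, j].
    F_causal = torch.cumsum(S, dim=-2) - S

    # The processor scales F by 1/seq_len and by masking_strength, then subtracts it
    # from the attention logits, so tokens that other tokens "vote against" attend less.
    print(F_bidirectional[0])
    print(F_causal[0])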
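
Also not part of the patch: a possible usage sketch for the new ModelMixin methods. The pipeline class and checkpoint id are placeholders chosen for illustration; the patch only adds enable_selective_attention / disable_selective_attention plus the processor, and only attention modules currently running AttnProcessor2_0 (or with no processor set) are switched over.

    # Usage sketch (illustration only; pipeline class and checkpoint id are placeholders).
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Swap every AttnProcessor2_0 in the UNet for the selective processor.
    # bidirectional=True is the diffusion setting; masking_strength=0.0 would
    # disable the selective masking entirely, which is handy for A/B debugging.
    pipe.unet.enable_selective_attention(
        selection_head_idx=0,
        bidirectional=True,
        enable_pruning=False,
        masking_strength=0.1,
    )

    image = pipe("an astronaut riding a horse on the moon").images[0]

    # Revert the UNet to the standard scaled-dot-product attention processor.
    pipe.unet.disable_selective_attention()

Cross-attention layers keep their original behavior even after the switch, since the processor applies selective masking (and optional pruning) only on self-attention calls.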