diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 66455d733aee..91c31fe6dcf1 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2785,6 +2785,227 @@ def __call__(
         return hidden_states


+class MemoryEfficientSelectiveAttnProcessor2_0:
+    r"""
+    Processor for implementing selective attention (https://arxiv.org/abs/2410.02703) with optional token pruning.
+
+    Args:
+        selection_head_idx (`int`, defaults to 0):
+            Index of the attention head to use for computing selection scores.
+        bidirectional (`bool`, defaults to `True`):
+            Whether to use bidirectional masking (for diffusion models) or causal masking (for LLMs).
+        enable_pruning (`bool`, defaults to `False`):
+            Whether to remove heavily masked tokens from the computation.
+        pruning_threshold (`float`, defaults to 1.0):
+            Tokens with masking scores above this threshold are pruned.
+        masking_strength (`float`, defaults to 0.1):
+            Scaling factor for the masking values. Lower values apply gentler masking. Set to 0.0 to disable
+            selective masking entirely (useful for debugging).
+    """
+
+    def __init__(
+        self,
+        selection_head_idx: int = 0,
+        bidirectional: bool = True,
+        enable_pruning: bool = False,
+        pruning_threshold: float = 1.0,
+        masking_strength: float = 0.1,
+    ):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "MemoryEfficientSelectiveAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        self.selection_head_idx = selection_head_idx
+        self.bidirectional = bidirectional
+        self.enable_pruning = enable_pruning
+        self.pruning_threshold = pruning_threshold
+        self.masking_strength = masking_strength
+
+    def _compute_selection_scores(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        selection_head_idx: int,
+    ) -> torch.Tensor:
+        """Compute selection scores S from the specified attention head."""
+        seq_len = query.shape[2]
+        head_dim = query.shape[3]
+
+        q_select = query[:, selection_head_idx, :, :]
+        k_select = key[:, selection_head_idx, :, :]
+
+        scale = head_dim ** -0.5
+        selection_logits = torch.matmul(q_select, k_select.transpose(-2, -1)) * scale
+
+        # Apply constraints: ReLU, zero first column (the first token is never masked),
+        # zero diagonal (a token cannot mask itself)
+        S = torch.relu(selection_logits)
+        S[:, :, 0] = 0
+        diagonal_mask = torch.eye(seq_len, device=S.device, dtype=torch.bool)
+        S = S.masked_fill(diagonal_mask.unsqueeze(0), 0)
+
+        return S
+
+    def _compute_masking_matrix(
+        self,
+        S: torch.Tensor,
+        bidirectional: bool = True,
+    ) -> torch.Tensor:
+        """Compute the accumulated masking matrix F from selection scores S."""
+        if bidirectional:
+            # Symmetric: F[i,j] = sum_{k!=i} S[k,j]
+            total_votes = S.sum(dim=-2, keepdim=True)
+            F = total_votes - S
+        else:
+            # Causal: F[i,j] = sum_{k<i} S[k,j]
+            F = torch.cumsum(S, dim=-2) - S
+
+        return F
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = (
+                "The `scale` argument is deprecated and will be ignored. Please remove it, "
+                "as passing it will raise an error in the future. `scale` should directly be "
+                "passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            )
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        is_self_attention = encoder_hidden_states is None
+        if is_self_attention:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Selective attention: only apply to self-attention with at least 2 tokens
+        if is_self_attention and query.shape[2] > 1:
+            if self.selection_head_idx >= attn.heads:
+                logger.warning(
+                    f"selection_head_idx ({self.selection_head_idx}) >= num_heads ({attn.heads}). "
+                    f"Falling back to head 0."
+                )
+                selection_head_idx = 0
+            else:
+                selection_head_idx = self.selection_head_idx
+
+            # Compute selection scores and masking matrix
+            S = self._compute_selection_scores(query, key, selection_head_idx)
+            F_mask = self._compute_masking_matrix(S, self.bidirectional)
+            F_mask = F_mask.unsqueeze(1).expand(-1, attn.heads, -1, -1)
+
+            # Normalize F_mask to prevent explosion for long sequences
+            # Scale by 1/seq_len to make masking strength sequence-length independent
+            seq_len = query.shape[2]
+            F_mask = F_mask / seq_len
+
+            # Apply masking strength scaling
+            # Lower strength = gentler masking, preserving more of the original attention
+            F_mask = F_mask * self.masking_strength
+
+            # Combine with the existing attention mask (F is subtracted from the attention logits)
+            selective_mask = -F_mask
+            if attention_mask is not None:
+                attention_mask = attention_mask + selective_mask
+            else:
+                attention_mask = selective_mask
+
+            # Optional pruning: remove heavily masked tokens
+            if self.enable_pruning:
+                max_masking_per_token = F_mask[:, 0, :, :].max(dim=-2)[0]
+                keep_mask = max_masking_per_token < self.pruning_threshold
+
+                # Ensure at least 10% of tokens are kept
+                min_tokens = max(1, int(sequence_length * 0.1))
+                tokens_to_keep = keep_mask.sum(dim=-1)
+
+                for b in range(batch_size):
+                    if tokens_to_keep[b] < min_tokens:
+                        _, indices = max_masking_per_token[b].topk(min_tokens, largest=False)
+                        keep_mask[b] = False
+                        keep_mask[b, indices] = True
+
+                # Apply pruning mask
+                if not keep_mask.all():
+                    prune_mask = ~keep_mask
+                    prune_mask = prune_mask.unsqueeze(1).unsqueeze(2)
+                    prune_mask = prune_mask.expand(-1, attn.heads, query.shape[2], -1)
+
+                    if attention_mask is not None:
+                        attention_mask = attention_mask.masked_fill(prune_mask, float("-inf"))
+                    else:
+                        attention_mask = torch.zeros_like(F_mask)
+                        attention_mask = attention_mask.masked_fill(prune_mask, float("-inf"))
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Linear projection
+        hidden_states = attn.to_out[0](hidden_states)
+        # Dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
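For reference, here is a minimal standalone sketch of the selection/masking math that `_compute_selection_scores` and `_compute_masking_matrix` implement above, run on tiny random tensors outside the diff. The tensor sizes, the seed, and the 0.1 strength are illustrative only; the last two lines show that the (scaled) masking matrix is simply subtracted from the attention logits before softmax, which is what passing `attn_mask=-F_mask` to `scaled_dot_product_attention` achieves.

import torch

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 1, 4, 6, 8
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)

# Selection scores S from a single head (head 0), mirroring _compute_selection_scores
q_sel, k_sel = query[:, 0], key[:, 0]
S = torch.relu(q_sel @ k_sel.transpose(-2, -1) * head_dim ** -0.5)
S[:, :, 0] = 0                                              # never mask the first token
S = S.masked_fill(torch.eye(seq_len, dtype=torch.bool), 0)  # a token does not mask itself

# Bidirectional accumulation, mirroring _compute_masking_matrix: F[i, j] = sum over k != i of S[k, j]
F_mask = S.sum(dim=-2, keepdim=True) - S

# Subtract the normalized, strength-scaled masking matrix from the logits before softmax
logits = query @ key.transpose(-2, -1) * head_dim ** -0.5
probs = torch.softmax(logits - 0.1 * F_mask.unsqueeze(1) / seq_len, dim=-1)
print(probs.shape)  # torch.Size([1, 4, 6, 6])

Tokens that accumulate large masking scores receive uniformly less attention from every query, which is the property the optional pruning branch in `__call__` exploits.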
 class XLAFlashAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
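As a quick sanity check of the processor in isolation, the following hedged sketch attaches it to a bare `Attention` block via the existing `set_processor` API and runs one forward pass; it assumes this diff is applied, and the dimensions are illustrative.

import torch
from diffusers.models.attention_processor import Attention, MemoryEfficientSelectiveAttnProcessor2_0

attn = Attention(query_dim=320, heads=8, dim_head=40)
attn.set_processor(MemoryEfficientSelectiveAttnProcessor2_0(masking_strength=0.1, bidirectional=True))

# Self-attention path: (batch, tokens, channels) input and no encoder_hidden_states
hidden_states = torch.randn(2, 64, 320)
out = attn(hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])

Cross-attention calls (with `encoder_hidden_states` set) fall through to the standard scaled-dot-product path, since the selective mask is only built when `is_self_attention` is true.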
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 41da95d3a2a2..8349a3e1bd05 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -435,6 +435,73 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         """
         self.set_use_memory_efficient_attention_xformers(False)

+    def enable_selective_attention(
+        self,
+        selection_head_idx: int = 0,
+        bidirectional: bool = True,
+        enable_pruning: bool = False,
+        pruning_threshold: float = 1.0,
+        masking_strength: float = 0.1,
+    ) -> None:
+        r"""
+        Enable selective attention for memory-efficient inference.
+
+        Based on "Selective Attention Improves Transformer" (https://arxiv.org/abs/2410.02703).
+
+        Parameters:
+            selection_head_idx (`int`, defaults to 0):
+                Index of the attention head to use for computing selection scores.
+            bidirectional (`bool`, defaults to `True`):
+                Whether to use bidirectional masking (for diffusion models) or causal masking (for LLMs).
+            enable_pruning (`bool`, defaults to `False`):
+                Whether to remove heavily masked tokens from the computation.
+            pruning_threshold (`float`, defaults to 1.0):
+                Tokens with masking scores above this threshold are pruned.
+            masking_strength (`float`, defaults to 0.1):
+                Scaling factor for the masking values. Lower values apply gentler masking; start with 0.1 and
+                increase if needed.
+        """
+        from .attention_processor import AttnProcessor2_0, MemoryEfficientSelectiveAttnProcessor2_0
+
+        def fn_recursive_set_selective_attention(module: torch.nn.Module):
+            if hasattr(module, "set_processor"):
+                current_processor = module.get_processor() if hasattr(module, "get_processor") else None
+                if current_processor is None or isinstance(current_processor, AttnProcessor2_0):
+                    module.set_processor(
+                        MemoryEfficientSelectiveAttnProcessor2_0(
+                            selection_head_idx=selection_head_idx,
+                            bidirectional=bidirectional,
+                            enable_pruning=enable_pruning,
+                            pruning_threshold=pruning_threshold,
+                            masking_strength=masking_strength,
+                        )
+                    )
+
+            for child in module.children():
+                fn_recursive_set_selective_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_selective_attention(module)
+
+    def disable_selective_attention(self) -> None:
+        r"""
+        Disable selective attention and revert to the standard attention processor.
+        """
+        from .attention_processor import AttnProcessor2_0, MemoryEfficientSelectiveAttnProcessor2_0
+
+        def fn_recursive_disable_selective_attention(module: torch.nn.Module):
+            if hasattr(module, "set_processor"):
+                current_processor = module.get_processor() if hasattr(module, "get_processor") else None
+                if isinstance(current_processor, MemoryEfficientSelectiveAttnProcessor2_0):
+                    module.set_processor(AttnProcessor2_0())
+
+            for child in module.children():
+                fn_recursive_disable_selective_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_disable_selective_attention(module)
+
     def enable_layerwise_casting(
         self,
         storage_dtype: torch.dtype = torch.float8_e4m3fn,
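A minimal usage sketch of the new model-level helpers, assuming this diff is applied; the checkpoint name and argument values are illustrative, not recommendations. `enable_selective_attention` only replaces modules that currently use `AttnProcessor2_0` (or have no processor set), and `disable_selective_attention` reverts exactly those modules.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Turn on selective attention for every eligible attention module of the UNet
pipe.unet.enable_selective_attention(
    selection_head_idx=0,
    bidirectional=True,
    enable_pruning=False,
    masking_strength=0.1,
)
image = pipe("a photo of an astronaut riding a horse").images[0]

# Revert to the standard AttnProcessor2_0
pipe.unet.disable_selective_attention()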