From 4a475bb5ed607571fe4cf50c3f9d36a23b86f456 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Fri, 7 Nov 2025 18:07:46 +0800 Subject: [PATCH 1/9] opt beam search & add fa ing... --- fireredasr/models/fireredasr_aed.py | 1 + .../models/module/transformer_decoder.py | 320 ++++++++++++++---- 2 files changed, 248 insertions(+), 73 deletions(-) diff --git a/fireredasr/models/fireredasr_aed.py b/fireredasr/models/fireredasr_aed.py index 4938c2c..a3ab3b3 100644 --- a/fireredasr/models/fireredasr_aed.py +++ b/fireredasr/models/fireredasr_aed.py @@ -4,6 +4,7 @@ from fireredasr.models.module.transformer_decoder import TransformerDecoder +@torch.compile(mode="max-autotune") class FireRedAsrAed(torch.nn.Module): @classmethod def from_args(cls, args): diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 2088b08..eb4a4d2 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -4,6 +4,18 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +import math +import os +import torch +import torch.nn as nn +import xformers.ops as xops + +from flash_attn import flash_attn_func, flash_attn_varlen_func + + +ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA") # Option: "NATIVE", "SDPA", "XFORMERS" +MultiHeadAttention = None +print("ATTENTION_BACKEND: ", ATTENTION_BACKEND) class TransformerDecoder(nn.Module): @@ -22,7 +34,6 @@ def __init__( # Components self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id) self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) - self.dropout = nn.Dropout(residual_dropout) self.layer_stack = nn.ModuleList() for l in range(n_layers): @@ -34,85 +45,126 @@ def __init__( self.tgt_word_prj.weight = self.tgt_word_emb.weight self.scale = (d_model ** 0.5) + self.scores_map = {} + self.stride_map = {} + self.filter_indexes = {} + self.active_masks = {} + def batch_beam_search(self, encoder_outputs, src_masks, beam_size=1, nbest=1, decode_max_len=0, softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): B = beam_size N, Ti, H = encoder_outputs.size() + M = N * B device = encoder_outputs.device maxlen = decode_max_len if decode_max_len > 0 else Ti assert eos_penalty > 0.0 and eos_penalty <= 1.0 # Init - encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H) - src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti) - ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device) + encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(M, Ti, H) + src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(M, -1, Ti) + ys = torch.ones(M, 1, device=device).fill_(self.sos_id).long() caches: List[Optional[Tensor]] = [] for _ in range(self.n_layers): - caches.append(None) - scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device) - scores = scores.repeat(N).view(N*B, 1) + caches.append(torch.empty(M, 0, H, device=device, dtype=encoder_outputs.dtype)) + + if B not in self.scores_map: + scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device) + self.scores_map[B] = scores + scores = self.scores_map[B].repeat(N).view(M, 1) + finished_mask_score = self.scores_map[B] is_finished = torch.zeros_like(scores) + + if (B, N) not in self.stride_map: + stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(M).to(device).long() + filter_index = torch.arange(M, device=device, dtype=torch.int32) + active_mask = 
torch.ones(M, dtype=torch.bool, device=device) + self.stride_map[(B, N)] = stride + self.filter_indexes[M] = filter_index + self.active_masks[M] = active_mask + stride = self.stride_map[(B, N)] + active_mask = self.active_masks[M] + active_indices = self.filter_indexes[M] + last_t_logit = None # Autoregressive Prediction for t in range(maxlen): tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) - - dec_output = self.dropout( - self.tgt_word_emb(ys) * self.scale + - self.positional_encoding(ys)) - - i = 0 - for dec_layer in self.layer_stack: + dec_output = self.tgt_word_emb(ys) * self.scale + self.positional_encoding(ys) + + def expand(f, mask, indices, value=0.0, t=None): + if t is None: + t = torch.full((len(mask), *list(f.shape[1:])), value, dtype=f.dtype, device=f.device) + t[indices] = f + return t + + dec_output = dec_output[active_indices] + t_encoder_outputs = encoder_outputs[active_indices] + tgt_mask = tgt_mask[active_indices] + t_src_mask = src_mask[active_indices] + + for i, dec_layer in enumerate(self.layer_stack): dec_output = dec_layer.forward( - dec_output, encoder_outputs, - tgt_mask, src_mask, - cache=caches[i]) - caches[i] = dec_output - i += 1 + dec_output, + t_encoder_outputs, + tgt_mask, + t_src_mask, + cache=caches[i][active_indices]) + caches[i] = dec_output dec_output = self.layer_norm_out(dec_output) t_logit = self.tgt_word_prj(dec_output[:, -1]) + if last_t_logit is None: + last_t_logit = t_logit + else: + last_t_logit[active_indices] = t_logit + t_logit = last_t_logit t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1) - + if eos_penalty != 1.0: t_scores[:, self.eos_id] *= eos_penalty t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1) - t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished) + t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished, mask_score=finished_mask_score) t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished) # Accumulated scores = scores + t_topB_scores - # Pruning + # Pruning scores = scores.view(N, B*B) scores, topB_score_ids = torch.topk(scores, k=B, dim=1) scores = scores.view(-1, 1) - - topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(N*B) - stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device) - topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long() + + topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(M) + topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride # Update ys ys = ys[topB_row_number_in_ys] - t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1) + t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(M, 1) ys = torch.cat((ys, t_ys), dim=1) # Update caches new_caches: List[Optional[Tensor]] = [] + target = torch.full((len(active_mask), *list(caches[0].shape[1:])), + 0.0, + dtype=caches[0].dtype, + device=caches[0].device) for cache in caches: - if cache is not None: - new_caches.append(cache[topB_row_number_in_ys]) + new_caches.append(expand(cache, active_mask, active_indices,t=target)[topB_row_number_in_ys]) caches = new_caches # Update finished state is_finished = t_ys.eq(self.eos_id) - if is_finished.sum().item() == N*B: - break + is_finished_n = is_finished.sum().item() + active_mask = ~is_finished.squeeze() + active_indices = torch.nonzero_static(active_mask, size=M - int(is_finished_n)).squeeze(1) + if is_finished_n == M: + break + # Length penalty 
(follow GNMT) scores = scores.view(N, B) ys = ys.view(N, B, -1) @@ -123,28 +175,22 @@ def batch_beam_search(self, encoder_outputs, src_masks, nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1) nbest_scores = -1.0 * nbest_scores index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long() - nbest_ys = ys.view(N*B, -1)[index.view(-1)] - nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1) - nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1) - - # result - nbest_hyps: List[List[Dict[str, Tensor]]] = [] - for n in range(N): - n_nbest_hyps: List[Dict[str, Tensor]] = [] - for i, score in enumerate(nbest_scores[n]): - new_hyp = { - "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]] - } - n_nbest_hyps.append(new_hyp) - nbest_hyps.append(n_nbest_hyps) - return nbest_hyps + nbest_ys = ys.view(M, -1)[index.view(-1)].view(N, nbest_ids.size(1), -1) + nbest_ys_lengths = ys_lengths.view(M)[index.view(-1)].view(N, -1) + + return [ + [ + {"yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]]} + for i, _ in enumerate(nbest_scores[n]) + ] + for n in range(N) + ] def ignored_target_position_is_0(self, padded_targets, ignore_id): mask = torch.ne(padded_targets, ignore_id) mask = mask.unsqueeze(dim=1) T = padded_targets.size(-1) - upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype) - upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device) + upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype).to(mask.device) return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8) def upper_triangular_is_0(self, size): @@ -152,10 +198,9 @@ def upper_triangular_is_0(self, size): tri_left_ones = torch.tril(ones) return tri_left_ones.to(torch.uint8) - def set_finished_beam_score_to_zero(self, scores, is_finished): + def set_finished_beam_score_to_zero(self, scores, is_finished, mask_score): NB, B = scores.size() is_finished = is_finished.float() - mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device) mask_score = mask_score.view(1, B).repeat(NB, 1) return scores * (1 - is_finished) + mask_score * is_finished @@ -174,10 +219,21 @@ class DecoderLayer(nn.Module): def __init__(self, d_model, n_head, dropout): super().__init__() self.self_attn_norm = nn.LayerNorm(d_model) - self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) + if ATTENTION_BACKEND.upper() == "NATIVE": + MultiHeadAttention = DecoderMultiHeadAttention + elif ATTENTION_BACKEND.upper() == "SDPA": + MultiHeadAttention = DecoderMHATorchSDPA + elif ATTENTION_BACKEND.upper() == "XFORMERS": + MultiHeadAttention = DecoderMHAXFormers + elif ATTENTION_BACKEND.upper() == "FLASH_ATTN": + MultiHeadAttention = DecoderMHAFlashAttn + else: + print("Unsupported attention backend: ", ATTENTION_BACKEND) + exit(1) + self.self_attn = MultiHeadAttention(d_model, n_head, dropout) self.cross_attn_norm = nn.LayerNorm(d_model) - self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) + self.cross_attn = MultiHeadAttention(d_model, n_head, dropout) self.mlp_norm = nn.LayerNorm(d_model) self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout) @@ -187,30 +243,24 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, x = dec_input residual = x x = self.self_attn_norm(x) - if cache is not None: + if cache.shape[1]: xq = x[:, -1:, :] residual = residual[:, -1:, :] self_attn_mask = self_attn_mask[:, -1:, :] else: xq = x - x = self.self_attn(xq, x, x, mask=self_attn_mask) - x = residual + x - + x = residual + 
self.self_attn(xq, x, x, mask=self_attn_mask) residual = x - x = self.cross_attn_norm(x) - x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask) - x = residual + x + x = residual+self.cross_attn(self.cross_attn_norm(x), enc_output, enc_output, mask=cross_attn_mask, is_cross=True) residual = x x = self.mlp_norm(x) x = residual + self.mlp(x) - if cache is not None: - x = torch.cat([cache, x], dim=1) - + x = torch.cat([cache, x], dim=1) return x - +# Native MHA class DecoderMultiHeadAttention(nn.Module): def __init__(self, d_model, n_head, dropout=0.1): super().__init__() @@ -221,13 +271,10 @@ def __init__(self, d_model, n_head, dropout=0.1): self.w_qs = nn.Linear(d_model, n_head * self.d_k) self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * self.d_k) - - self.attention = DecoderScaledDotProductAttention( - temperature=self.d_k ** 0.5) + self.attention = DecoderScaledDotProductAttention(temperature=self.d_k ** 0.5) self.fc = nn.Linear(n_head * self.d_k, d_model) - self.dropout = nn.Dropout(dropout) - def forward(self, q, k, v, mask=None): + def forward(self, q, k, v, mask=None, is_cross=False): bs = q.size(0) q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) @@ -244,11 +291,9 @@ def forward(self, q, k, v, mask=None): output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) output = self.fc(output) - output = self.dropout(output) return output - class DecoderScaledDotProductAttention(nn.Module): def __init__(self, temperature): super().__init__() @@ -266,21 +311,150 @@ def forward(self, q, k, v, mask=None): output = torch.matmul(attn, v) return output +# MHA with Torch SDPA +class DecoderMHATorchSDPA(nn.Module): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__() + self.d_model = d_model + self.n_head = n_head + self.d_k = d_model // n_head + + self.w_qs = nn.Linear(d_model, n_head * self.d_k) + self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * self.d_k) + self.attention = DecoderTorchSDPA(temperature=self.d_k ** 0.5) + self.fc = nn.Linear(n_head * self.d_k, d_model) + + def forward(self, q, k, v, mask=None, is_cross=False): + bs = q.size(0) + + q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + + output = self.attention(q, k, v, mask=mask.unsqueeze(1)) + output = self.fc(output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)) + return output + +class DecoderTorchSDPA(nn.Module): + def __init__(self, temperature): + super().__init__() + self.temperature = temperature + + def forward(self, q, k, v, mask=None): + """ + q, k, v: (batch, num_heads, seq_len, d_k) + mask: optional attention mask + - If boolean: shape (batch, 1, seq_len, seq_len) or broadcastable. + True means 'mask out'. + - If float: same shape, with -inf for masked positions. 
+ """ + # F.scaled_dot_product_attention will: + # - scale internally + # - apply softmax + # - apply mask if given + # - compute attention output + output = F.scaled_dot_product_attention( + q, k, v, + attn_mask=mask.eq(1), + dropout_p=0.0, # set >0 only during training + is_causal=False, # set True to get causal masking automatically + ) + return output + +# MHA with xFormers +class DecoderMHAXFormers(nn.Module): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__() + assert d_model % n_head == 0 + self.d_model = d_model + self.n_head = n_head + self.d_k = d_model // n_head + + self.w_qs = nn.Linear(d_model, n_head * self.d_k) + self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * self.d_k) + + self.fc = nn.Linear(n_head * self.d_k, d_model) + + def forward(self, q, k, v, mask=None, is_cross=False): + bs = q.size(0) + + # projection and transform to (batch*n_head, seq_len, head_dim) + q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) + # single-step reshape+transpose + output = xops.memory_efficient_attention(q, k, v) + # back to (bs, seq_len, d_model) + output = output.reshape(bs, self.n_head, -1, self.d_k).transpose(1, 2).contiguous().view(bs, -1, self.d_model) + output = self.fc(output) + return output + + +class DecoderMHAFlashAttn(nn.Module): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__() + self.d_model = d_model + self.n_head = n_head + self.d_k = d_model // n_head + + self.w_qs = nn.Linear(d_model, n_head * self.d_k) + self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * self.d_k) + self.fc = nn.Linear(n_head * self.d_k, d_model) + #not support mask yet, only work well for batch=1 + def forward(self, q, k, v, mask=None, is_cross=False): + is_casual = not is_cross + bs = q.size(0) + + if not is_casual: + q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) + output = flash_attn_func(q, k, v) + else: + q = self.w_qs(q).view(-1, self.n_head, self.d_k) + k = self.w_ks(k).view(-1, self.n_head, self.d_k) + v = self.w_vs(v).view(-1, self.n_head, self.d_k) + mask = mask.squeeze(1) + bool_mask = mask.view(-1).bool() + + if q.shape[0] == len(bool_mask): + var_q = q[bool_mask] + else: + var_q = q + var_k = k[bool_mask] + var_v = v[bool_mask] + + seq_lens = mask.sum(dim=1, dtype=torch.int32) + end_pos=torch.cumsum(seq_lens, dim=0) + start_pos=torch.tensor([0], device=seq_lens.device) + pos=torch.cat([start_pos, end_pos]).to(torch.int32) + cu_seqlens_q = pos + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = max(seq_lens) + max_seqlen_k = max_seqlen_q + output = flash_attn_varlen_func(var_q, var_k, var_v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=is_casual) + output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) + output = self.fc(output) + return output + + +# @torch.compile(mode="reduce-overhead", backend="inductor") class PositionwiseFeedForward(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): super().__init__() self.w_1 = nn.Linear(d_model, d_ff) self.act = nn.GELU() self.w_2 = nn.Linear(d_ff, d_model) - self.dropout 
= nn.Dropout(dropout) def forward(self, x): output = self.w_2(self.act(self.w_1(x))) - output = self.dropout(output) return output - +# @torch.compile(mode="reduce-overhead", backend="inductor") class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() From c052a4fbfd6889e84e93d4ae6261ed6c81cf06c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Mon, 10 Nov 2025 18:55:53 +0800 Subject: [PATCH 2/9] fix flash attention --- .../models/module/transformer_decoder.py | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index eb4a4d2..8db77ca 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -1,4 +1,5 @@ from typing import List, Optional, Dict +import inspect import torch import torch.nn as nn @@ -12,6 +13,7 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func +import einops ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA") # Option: "NATIVE", "SDPA", "XFORMERS" MultiHeadAttention = None @@ -160,6 +162,7 @@ def expand(f, mask, indices, value=0.0, t=None): is_finished = t_ys.eq(self.eos_id) is_finished_n = is_finished.sum().item() active_mask = ~is_finished.squeeze() + #active_indices = self.filter_indexes[M][active_mask] active_indices = torch.nonzero_static(active_mask, size=M - int(is_finished_n)).squeeze(1) if is_finished_n == M: @@ -385,9 +388,13 @@ def forward(self, q, k, v, mask=None, is_cross=False): k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) # single-step reshape+transpose + #q = einops.rearrange(self.w_qs(q), 'b s (h d) -> (b h) s d', h=self.n_head) + #k = einops.rearrange(self.w_ks(k), 'b s (h d) -> (b h) s d', h=self.n_head) + #v = einops.rearrange(self.w_vs(v), 'b s (h d) -> (b h) s d', h=self.n_head) output = xops.memory_efficient_attention(q, k, v) # back to (bs, seq_len, d_model) output = output.reshape(bs, self.n_head, -1, self.d_k).transpose(1, 2).contiguous().view(bs, -1, self.d_model) + #output = einops.rearrange(output, '(b h) s d -> b s (h d)', b=bs, h=self.n_head) output = self.fc(output) return output @@ -404,16 +411,16 @@ def __init__(self, d_model, n_head, dropout=0.1): self.w_vs = nn.Linear(d_model, n_head * self.d_k) self.fc = nn.Linear(n_head * self.d_k, d_model) - #not support mask yet, only work well for batch=1 def forward(self, q, k, v, mask=None, is_cross=False): is_casual = not is_cross bs = q.size(0) - if not is_casual: + if is_casual: q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) - output = flash_attn_func(q, k, v) + output = flash_attn_func(q, k, v, causal=is_casual) + output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) else: q = self.w_qs(q).view(-1, self.n_head, self.d_k) k = self.w_ks(k).view(-1, self.n_head, self.d_k) @@ -421,23 +428,27 @@ def forward(self, q, k, v, mask=None, is_cross=False): mask = mask.squeeze(1) bool_mask = mask.view(-1).bool() - if q.shape[0] == len(bool_mask): - var_q = q[bool_mask] - else: - var_q = q var_k = k[bool_mask] var_v = v[bool_mask] seq_lens = mask.sum(dim=1, dtype=torch.int32) - end_pos=torch.cumsum(seq_lens, dim=0) - 
start_pos=torch.tensor([0], device=seq_lens.device) + start_pos=torch.cumsum(seq_lens, dim=0) - seq_lens + end_pos=torch.tensor([torch.sum(seq_lens)], device=seq_lens.device) pos=torch.cat([start_pos, end_pos]).to(torch.int32) - cu_seqlens_q = pos - cu_seqlens_k = cu_seqlens_q - max_seqlen_q = max(seq_lens) - max_seqlen_k = max_seqlen_q - output = flash_attn_varlen_func(var_q, var_k, var_v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=is_casual) - output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) + cu_seqlens_q = torch.arange(0, bs + 1, 1, device=seq_lens.device, dtype=torch.int32) + cu_seqlens_k = pos + max_seqlen_q = 1 + max_seqlen_k = max(seq_lens) + + output = flash_attn_varlen_func(q=q, + k=var_k, + v=var_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=is_casual) + output = output.contiguous().view(bs, -1, self.d_model) output = self.fc(output) return output From 7eaadfca9f580f9ea08c163d1d215b7110bd99c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Tue, 11 Nov 2025 18:05:38 +0800 Subject: [PATCH 3/9] optimize flash attention --- .../models/module/transformer_decoder.py | 128 ++++++++++-------- 1 file changed, 69 insertions(+), 59 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 8db77ca..1790f4a 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -52,6 +52,11 @@ def __init__( self.filter_indexes = {} self.active_masks = {} + + def clear(self): + for dec_layer in self.layer_stack: + dec_layer.clear() + def batch_beam_search(self, encoder_outputs, src_masks, beam_size=1, nbest=1, decode_max_len=0, @@ -112,7 +117,8 @@ def expand(f, mask, indices, value=0.0, t=None): t_encoder_outputs, tgt_mask, t_src_mask, - cache=caches[i][active_indices]) + cache=caches[i][active_indices], + active_indices=active_indices) caches[i] = dec_output dec_output = self.layer_norm_out(dec_output) @@ -240,9 +246,12 @@ def __init__(self, d_model, n_head, dropout): self.mlp_norm = nn.LayerNorm(d_model) self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout) + + def clear(self): + self.cross_attn.clear() def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, - cache=None): + cache=None, active_indices=None): x = dec_input residual = x x = self.self_attn_norm(x) @@ -254,7 +263,12 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, xq = x x = residual + self.self_attn(xq, x, x, mask=self_attn_mask) residual = x - x = residual+self.cross_attn(self.cross_attn_norm(x), enc_output, enc_output, mask=cross_attn_mask, is_cross=True) + x = residual+self.cross_attn(self.cross_attn_norm(x), + enc_output, + enc_output, + mask=cross_attn_mask, + is_cross=True, + active_indices=active_indices) residual = x x = self.mlp_norm(x) @@ -262,9 +276,8 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, x = torch.cat([cache, x], dim=1) return x - -# Native MHA -class DecoderMultiHeadAttention(nn.Module): + +class BaseMultiHeadAttention(nn.Module): def __init__(self, d_model, n_head, dropout=0.1): super().__init__() self.d_model = d_model @@ -274,10 +287,21 @@ def __init__(self, d_model, n_head, dropout=0.1): self.w_qs = nn.Linear(d_model, n_head * self.d_k) self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * self.d_k) - 
self.attention = DecoderScaledDotProductAttention(temperature=self.d_k ** 0.5) self.fc = nn.Linear(n_head * self.d_k, d_model) - def forward(self, q, k, v, mask=None, is_cross=False): + def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): + raise NotImplementedError + + def clear(self): + pass + +# Native MHA +class DecoderMultiHeadAttention(BaseMultiHeadAttention): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__(d_model, n_head, dropout) + self.attention = DecoderScaledDotProductAttention(temperature=self.d_k ** 0.5) + + def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): bs = q.size(0) q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) @@ -315,20 +339,12 @@ def forward(self, q, k, v, mask=None): return output # MHA with Torch SDPA -class DecoderMHATorchSDPA(nn.Module): +class DecoderMHATorchSDPA(BaseMultiHeadAttention): def __init__(self, d_model, n_head, dropout=0.1): - super().__init__() - self.d_model = d_model - self.n_head = n_head - self.d_k = d_model // n_head - - self.w_qs = nn.Linear(d_model, n_head * self.d_k) - self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) - self.w_vs = nn.Linear(d_model, n_head * self.d_k) + super().__init__(d_model, n_head, dropout) self.attention = DecoderTorchSDPA(temperature=self.d_k ** 0.5) - self.fc = nn.Linear(n_head * self.d_k, d_model) - def forward(self, q, k, v, mask=None, is_cross=False): + def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): bs = q.size(0) q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) @@ -366,21 +382,11 @@ def forward(self, q, k, v, mask=None): return output # MHA with xFormers -class DecoderMHAXFormers(nn.Module): +class DecoderMHAXFormers(BaseMultiHeadAttention): def __init__(self, d_model, n_head, dropout=0.1): - super().__init__() - assert d_model % n_head == 0 - self.d_model = d_model - self.n_head = n_head - self.d_k = d_model // n_head - - self.w_qs = nn.Linear(d_model, n_head * self.d_k) - self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) - self.w_vs = nn.Linear(d_model, n_head * self.d_k) + super().__init__(d_model, n_head, dropout) - self.fc = nn.Linear(n_head * self.d_k, d_model) - - def forward(self, q, k, v, mask=None, is_cross=False): + def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): bs = q.size(0) # projection and transform to (batch*n_head, seq_len, head_dim) @@ -399,19 +405,20 @@ def forward(self, q, k, v, mask=None, is_cross=False): return output -class DecoderMHAFlashAttn(nn.Module): +class DecoderMHAFlashAttn(BaseMultiHeadAttention): def __init__(self, d_model, n_head, dropout=0.1): - super().__init__() - self.d_model = d_model - self.n_head = n_head - self.d_k = d_model // n_head - - self.w_qs = nn.Linear(d_model, n_head * self.d_k) - self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) - self.w_vs = nn.Linear(d_model, n_head * self.d_k) - self.fc = nn.Linear(n_head * self.d_k, d_model) + super().__init__(d_model, n_head, dropout) + + self.cross_cache_k = None + self.cross_cache_v = None + self.cross_cache_seqs = None + + def clear(self): + self.cross_cache_k = None + self.cross_cache_v = None + self.cross_cache_seqs = None - def forward(self, q, k, v, mask=None, is_cross=False): + def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): is_casual = not is_cross bs = q.size(0) @@ -422,31 +429,34 @@ def forward(self, q, k, v, mask=None, is_cross=False): output = flash_attn_func(q, k, v, causal=is_casual) output 
= output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) else: - q = self.w_qs(q).view(-1, self.n_head, self.d_k) - k = self.w_ks(k).view(-1, self.n_head, self.d_k) - v = self.w_vs(v).view(-1, self.n_head, self.d_k) mask = mask.squeeze(1) bool_mask = mask.view(-1).bool() + q = self.w_qs(q).view(-1, self.n_head, self.d_k) - var_k = k[bool_mask] - var_v = v[bool_mask] - - seq_lens = mask.sum(dim=1, dtype=torch.int32) - start_pos=torch.cumsum(seq_lens, dim=0) - seq_lens - end_pos=torch.tensor([torch.sum(seq_lens)], device=seq_lens.device) - pos=torch.cat([start_pos, end_pos]).to(torch.int32) + if self.cross_cache_k is None or self.cross_cache_v is None: + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) + + seq_lens = mask.sum(dim=1, dtype=torch.int32) + self.cross_cache_k = k + self.cross_cache_v = v + self.cross_cache_seqs = seq_lens + + seq_lens = self.cross_cache_seqs[active_indices] + total, _max_seq_k = seq_lens.sum().item(), seq_lens.max().item() + bool_indices = torch.nonzero_static(bool_mask, size=total).squeeze(1) + var_k = self.cross_cache_k[active_indices].view(-1, self.n_head, self.d_k)[bool_indices] + var_v = self.cross_cache_v[active_indices].view(-1, self.n_head, self.d_k)[bool_indices] cu_seqlens_q = torch.arange(0, bs + 1, 1, device=seq_lens.device, dtype=torch.int32) - cu_seqlens_k = pos - max_seqlen_q = 1 - max_seqlen_k = max(seq_lens) - + cu_seqlens_k = torch.zeros(bs + 1, device=seq_lens.device, dtype=torch.int32) + cu_seqlens_k[1:] = torch.cumsum(seq_lens, dim=0) output = flash_attn_varlen_func(q=q, k=var_k, v=var_v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + max_seqlen_q=1, + max_seqlen_k=_max_seq_k, causal=is_casual) output = output.contiguous().view(bs, -1, self.d_model) output = self.fc(output) From e0e24b9d510a3b5c732b6c95a8a98071edbef48f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Tue, 11 Nov 2025 19:56:51 +0800 Subject: [PATCH 4/9] for flash attention, casual mask is useless --- fireredasr/models/module/transformer_decoder.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 1790f4a..3875753 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -97,7 +97,11 @@ def batch_beam_search(self, encoder_outputs, src_masks, # Autoregressive Prediction for t in range(maxlen): - tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) + if ATTENTION_BACKEND == "FLASH_ATTN": + tgt_mask = ys + else: + tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) + dec_output = self.tgt_word_emb(ys) * self.scale + self.positional_encoding(ys) def expand(f, mask, indices, value=0.0, t=None): @@ -258,7 +262,7 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, if cache.shape[1]: xq = x[:, -1:, :] residual = residual[:, -1:, :] - self_attn_mask = self_attn_mask[:, -1:, :] + self_attn_mask = self.self_attn.parse_mask(self_attn_mask) else: xq = x x = residual + self.self_attn(xq, x, x, mask=self_attn_mask) @@ -294,6 +298,9 @@ def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): def clear(self): pass + + def parse_mask(self, mask): + return mask[:, -1:, :] # Native MHA class DecoderMultiHeadAttention(BaseMultiHeadAttention): @@ -441,6 +448,7 @@ def forward(self, q, k, v, 
mask=None, is_cross=False, active_indices=None): self.cross_cache_k = k self.cross_cache_v = v self.cross_cache_seqs = seq_lens + seq_lens = self.cross_cache_seqs[active_indices] total, _max_seq_k = seq_lens.sum().item(), seq_lens.max().item() @@ -461,6 +469,9 @@ def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): output = output.contiguous().view(bs, -1, self.d_model) output = self.fc(output) return output + + def parse_mask(self, mask): + return mask # @torch.compile(mode="reduce-overhead", backend="inductor") From 277303f0e62e666062cf30ef1bacf368ab9d5c6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Wed, 12 Nov 2025 14:40:36 +0800 Subject: [PATCH 5/9] add missing commit --- fireredasr/models/fireredasr_aed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fireredasr/models/fireredasr_aed.py b/fireredasr/models/fireredasr_aed.py index a3ab3b3..522a46d 100644 --- a/fireredasr/models/fireredasr_aed.py +++ b/fireredasr/models/fireredasr_aed.py @@ -29,6 +29,7 @@ def transcribe(self, padded_input, input_lengths, beam_size=1, nbest=1, decode_max_len=0, softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): enc_outputs, _, enc_mask = self.encoder(padded_input, input_lengths) + self.decoder.clear() nbest_hyps = self.decoder.batch_beam_search( enc_outputs, enc_mask, beam_size, nbest, decode_max_len, From c2ae62bd574ebd30603d3bef0327bcf97db3d867 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Wed, 12 Nov 2025 20:03:07 +0800 Subject: [PATCH 6/9] add attention meta --- .../models/module/transformer_decoder.py | 78 ++++++++++++++----- 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 3875753..9e34095 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -20,6 +20,31 @@ print("ATTENTION_BACKEND: ", ATTENTION_BACKEND) +class AttentionMeta(object): + + def __init__(self): + self.seq_lens = None + self.cu_seqlens_q = None + self.cu_seqlens_k = None + self.max_seqlen_q = None + self.max_seqlen_k = None + self.total_seqlen_k = None + self.active_indices = None + + def update(self, seq_lens=None, + cu_seqlens_q=None, cu_seqlens_k=None, + max_seqlen_q=None, max_seqlen_k=None, + total_seqlen_k=None, + active_indices=None): + self.seq_lens = seq_lens + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_k = max_seqlen_k + self.total_seqlen_k = total_seqlen_k + self.active_indices = active_indices + + class TransformerDecoder(nn.Module): def __init__( self, sos_id, eos_id, pad_id, odim, @@ -47,6 +72,7 @@ def __init__( self.tgt_word_prj.weight = self.tgt_word_emb.weight self.scale = (d_model ** 0.5) + self.scores_map = {} self.stride_map = {} self.filter_indexes = {} @@ -55,7 +81,12 @@ def __init__( def clear(self): for dec_layer in self.layer_stack: - dec_layer.clear() + dec_layer.clear() + + + def cal_seq_lens(self, mask): + mask = mask.squeeze(1) + return mask.sum(dim=1, dtype=torch.int32) def batch_beam_search(self, encoder_outputs, src_masks, @@ -71,7 +102,7 @@ def batch_beam_search(self, encoder_outputs, src_masks, # Init encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(M, Ti, H) src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(M, -1, Ti) - ys = torch.ones(M, 1, device=device).fill_(self.sos_id).long() + raw_ys = ys = torch.ones(M, 1, 
device=device).fill_(self.sos_id).long() caches: List[Optional[Tensor]] = [] for _ in range(self.n_layers): caches.append(torch.empty(M, 0, H, device=device, dtype=encoder_outputs.dtype)) @@ -94,11 +125,13 @@ def batch_beam_search(self, encoder_outputs, src_masks, active_mask = self.active_masks[M] active_indices = self.filter_indexes[M] last_t_logit = None + + attn_meta = AttentionMeta() # Autoregressive Prediction for t in range(maxlen): if ATTENTION_BACKEND == "FLASH_ATTN": - tgt_mask = ys + tgt_mask = raw_ys else: tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) @@ -115,14 +148,22 @@ def expand(f, mask, indices, value=0.0, t=None): tgt_mask = tgt_mask[active_indices] t_src_mask = src_mask[active_indices] - for i, dec_layer in enumerate(self.layer_stack): + seq_lens = self.cal_seq_lens(t_src_mask) + seq_lens_cpu = seq_lens.cpu() + total_seqlen_k, max_seqlen_k = seq_lens_cpu.sum().item(), seq_lens_cpu.max().item() + attn_meta.update(seq_lens=seq_lens, + max_seqlen_k=max_seqlen_k, + total_seqlen_k=total_seqlen_k, + active_indices=active_indices) + + for i, dec_layer in enumerate(self.layer_stack): dec_output = dec_layer.forward( dec_output, t_encoder_outputs, tgt_mask, t_src_mask, cache=caches[i][active_indices], - active_indices=active_indices) + attn_meta=attn_meta) caches[i] = dec_output dec_output = self.layer_norm_out(dec_output) @@ -227,7 +268,6 @@ def get_ys_lengths(self, ys): return ys_lengths.int() - class DecoderLayer(nn.Module): def __init__(self, d_model, n_head, dropout): super().__init__() @@ -255,7 +295,7 @@ def clear(self): self.cross_attn.clear() def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, - cache=None, active_indices=None): + cache=None, attn_meta:AttentionMeta=None): x = dec_input residual = x x = self.self_attn_norm(x) @@ -272,7 +312,7 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, enc_output, mask=cross_attn_mask, is_cross=True, - active_indices=active_indices) + attn_meta=attn_meta) residual = x x = self.mlp_norm(x) @@ -293,7 +333,7 @@ def __init__(self, d_model, n_head, dropout=0.1): self.w_vs = nn.Linear(d_model, n_head * self.d_k) self.fc = nn.Linear(n_head * self.d_k, d_model) - def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): raise NotImplementedError def clear(self): @@ -308,7 +348,7 @@ def __init__(self, d_model, n_head, dropout=0.1): super().__init__(d_model, n_head, dropout) self.attention = DecoderScaledDotProductAttention(temperature=self.d_k ** 0.5) - def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): bs = q.size(0) q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) @@ -351,7 +391,7 @@ def __init__(self, d_model, n_head, dropout=0.1): super().__init__(d_model, n_head, dropout) self.attention = DecoderTorchSDPA(temperature=self.d_k ** 0.5) - def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): bs = q.size(0) q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) @@ -393,7 +433,7 @@ class DecoderMHAXFormers(BaseMultiHeadAttention): def __init__(self, d_model, n_head, dropout=0.1): super().__init__(d_model, n_head, dropout) - def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): + def forward(self, q, k, v, mask=None, 
is_cross=False, attn_meta:AttentionMeta=None): bs = q.size(0) # projection and transform to (batch*n_head, seq_len, head_dim) @@ -418,14 +458,12 @@ def __init__(self, d_model, n_head, dropout=0.1): self.cross_cache_k = None self.cross_cache_v = None - self.cross_cache_seqs = None def clear(self): self.cross_cache_k = None self.cross_cache_v = None - self.cross_cache_seqs = None - def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): is_casual = not is_cross bs = q.size(0) @@ -444,14 +482,12 @@ def forward(self, q, k, v, mask=None, is_cross=False, active_indices=None): k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) - seq_lens = mask.sum(dim=1, dtype=torch.int32) self.cross_cache_k = k self.cross_cache_v = v - self.cross_cache_seqs = seq_lens - - seq_lens = self.cross_cache_seqs[active_indices] - total, _max_seq_k = seq_lens.sum().item(), seq_lens.max().item() + active_indices = attn_meta.active_indices + seq_lens = attn_meta.seq_lens + total, _max_seq_k = attn_meta.total_seqlen_k, attn_meta.max_seqlen_k bool_indices = torch.nonzero_static(bool_mask, size=total).squeeze(1) var_k = self.cross_cache_k[active_indices].view(-1, self.n_head, self.d_k)[bool_indices] var_v = self.cross_cache_v[active_indices].view(-1, self.n_head, self.d_k)[bool_indices] @@ -502,4 +538,4 @@ def __init__(self, d_model, max_len=5000): def forward(self, x): length = x.size(1) - return self.pe[:, :length].clone().detach() + return self.pe[:, :length].clone().detach() \ No newline at end of file From 8b8b6a5203e46561f0eff454e1bd63039c3b4297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Wed, 12 Nov 2025 20:25:46 +0800 Subject: [PATCH 7/9] only fa init meta --- .../models/module/transformer_decoder.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 9e34095..6b1b7e0 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -130,31 +130,30 @@ def batch_beam_search(self, encoder_outputs, src_masks, # Autoregressive Prediction for t in range(maxlen): - if ATTENTION_BACKEND == "FLASH_ATTN": - tgt_mask = raw_ys - else: - tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) - - dec_output = self.tgt_word_emb(ys) * self.scale + self.positional_encoding(ys) def expand(f, mask, indices, value=0.0, t=None): if t is None: t = torch.full((len(mask), *list(f.shape[1:])), value, dtype=f.dtype, device=f.device) t[indices] = f return t - + + dec_output = self.tgt_word_emb(ys) * self.scale + self.positional_encoding(ys) dec_output = dec_output[active_indices] t_encoder_outputs = encoder_outputs[active_indices] - tgt_mask = tgt_mask[active_indices] t_src_mask = src_mask[active_indices] - seq_lens = self.cal_seq_lens(t_src_mask) - seq_lens_cpu = seq_lens.cpu() - total_seqlen_k, max_seqlen_k = seq_lens_cpu.sum().item(), seq_lens_cpu.max().item() - attn_meta.update(seq_lens=seq_lens, - max_seqlen_k=max_seqlen_k, - total_seqlen_k=total_seqlen_k, + if ATTENTION_BACKEND == "FLASH_ATTN": + tgt_mask = raw_ys[active_indices] + seq_lens = self.cal_seq_lens(t_src_mask) + seq_lens_cpu = seq_lens.cpu() + total_seqlen_k, max_seqlen_k = seq_lens_cpu.sum().item(), seq_lens_cpu.max().item() + attn_meta.update(seq_lens=seq_lens, + max_seqlen_k=max_seqlen_k, + 
total_seqlen_k=total_seqlen_k, active_indices=active_indices) + else: + tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) + tgt_mask = tgt_mask[active_indices] for i, dec_layer in enumerate(self.layer_stack): dec_output = dec_layer.forward( @@ -538,4 +537,4 @@ def __init__(self, d_model, max_len=5000): def forward(self, x): length = x.size(1) - return self.pe[:, :length].clone().detach() \ No newline at end of file + return self.pe[:, :length].clone().detach() From d9aee69eb52c277426cb6ee51708b5f45e58b578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Fri, 14 Nov 2025 09:57:32 +0800 Subject: [PATCH 8/9] add xformers support --- .../models/module/transformer_decoder.py | 109 ++++++++++++++---- 1 file changed, 87 insertions(+), 22 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 6b1b7e0..b2a376f 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -15,7 +15,7 @@ import einops -ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA") # Option: "NATIVE", "SDPA", "XFORMERS" +ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA") # Option: "NATIVE", "SDPA", "XFORMERS", "FLASH_ATTN" MultiHeadAttention = None print("ATTENTION_BACKEND: ", ATTENTION_BACKEND) @@ -142,7 +142,7 @@ def expand(f, mask, indices, value=0.0, t=None): t_encoder_outputs = encoder_outputs[active_indices] t_src_mask = src_mask[active_indices] - if ATTENTION_BACKEND == "FLASH_ATTN": + if ATTENTION_BACKEND in {"FLASH_ATTN", "XFORMERS"}: tgt_mask = raw_ys[active_indices] seq_lens = self.cal_seq_lens(t_src_mask) seq_lens_cpu = seq_lens.cpu() @@ -276,7 +276,7 @@ def __init__(self, d_model, n_head, dropout): elif ATTENTION_BACKEND.upper() == "SDPA": MultiHeadAttention = DecoderMHATorchSDPA elif ATTENTION_BACKEND.upper() == "XFORMERS": - MultiHeadAttention = DecoderMHAXFormers + MultiHeadAttention = DecoderXFormersAttention elif ATTENTION_BACKEND.upper() == "FLASH_ATTN": MultiHeadAttention = DecoderMHAFlashAttn else: @@ -414,11 +414,6 @@ def forward(self, q, k, v, mask=None): True means 'mask out'. - If float: same shape, with -inf for masked positions. 
""" - # F.scaled_dot_product_attention will: - # - scale internally - # - apply softmax - # - apply mask if given - # - compute attention output output = F.scaled_dot_product_attention( q, k, v, attn_mask=mask.eq(1), @@ -427,29 +422,99 @@ def forward(self, q, k, v, mask=None): ) return output -# MHA with xFormers -class DecoderMHAXFormers(BaseMultiHeadAttention): + +# xFormers Attention +class DecoderXFormersAttention(BaseMultiHeadAttention): def __init__(self, d_model, n_head, dropout=0.1): super().__init__(d_model, n_head, dropout) + self.attn_bias = None + + def set_cross_attn_bias(self, mask, bs, q_len, k_len, n_head, dtype, device, is_cross=False): + if is_cross: + mask = mask.to(torch.bool) + + # If mask only has 1 in q_len dimension, expand it + if mask.size(2) == 1 and q_len > 1: + mask = mask.expand(bs, 1, q_len, k_len) + + # Expand mask for all heads + mask = mask.expand(bs, n_head, q_len, k_len) \ + .reshape(bs * n_head, q_len, k_len) + + # Alignment requirement for xformers: pad allocation to multiple of 8 + pad_k = ((k_len + 7) // 8) * 8 + pad_q = ((q_len + 7) // 8) * 8 + + bias_full = torch.zeros(bs * n_head, pad_q, pad_k, + dtype=dtype, device=device) + + bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf")) + + # Slice down to actual shape but keep aligned backing storage + self.attn_bias = bias_full[:, :q_len, :k_len] + else: + print("Unknown attention type used, only support `cross_attention`") + + def get_attn_bias(self): + return self.attn_bias + + def reset_attn_bias(self): + self.attn_bias = None def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): bs = q.size(0) + + q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(1) + + original_query = q + + # Save lengths + q_len = q.size(2) # seq_len_q + k_len = k.size(2) # seq_len_k + dtype = q.dtype + + q = q.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16) + k = k.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16) + v = v.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16) + + output = None + if bs == 1: + output = xops.memory_efficient_attention(q, k, v) + else: + attn_bias = None + # --- causal self-attention --- + # q and k has same length, pass attn_bias=None + if not is_cross: + attn_bias = None + + # --- Cross-attention / padding mask --- + elif is_cross and mask is not None: + self.set_cross_attn_bias(mask, bs, q_len, k_len, self.n_head, q.dtype, q.device, is_cross=is_cross) + attn_bias = self.get_attn_bias() + else: + print("Unknown attention type used, only support `self_attention` and `cross_attention`") - # projection and transform to (batch*n_head, seq_len, head_dim) - q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) - k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) - v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2).reshape(bs * self.n_head, -1, self.d_k) - # single-step reshape+transpose - #q = einops.rearrange(self.w_qs(q), 'b s (h d) -> (b h) s d', h=self.n_head) - #k = einops.rearrange(self.w_ks(k), 'b s (h d) -> (b h) s d', h=self.n_head) - #v = einops.rearrange(self.w_vs(v), 'b s (h d) -> (b h) s d', h=self.n_head) - output = xops.memory_efficient_attention(q, k, v) - # back to (bs, seq_len, d_model) 
- output = output.reshape(bs, self.n_head, -1, self.d_k).transpose(1, 2).contiguous().view(bs, -1, self.d_model) - #output = einops.rearrange(output, '(b h) s d -> b s (h d)', b=bs, h=self.n_head) + # --- Run memory-efficient attention --- + + output = xops.memory_efficient_attention(q, k, v, + attn_bias=attn_bias) + # reshape back to (bs, seq_len, d_model) + output = output.view_as(original_query).to(dtype) + output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) output = self.fc(output) return output + def clear(self): + self.reset_attn_bias() + + def parse_mask(self, mask): + return mask + class DecoderMHAFlashAttn(BaseMultiHeadAttention): def __init__(self, d_model, n_head, dropout=0.1): From 66f951276b080342450ac5466332159af15e04a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=AF=E5=8B=A4?= Date: Fri, 14 Nov 2025 10:16:00 +0800 Subject: [PATCH 9/9] xformers no need to use seq lens now --- fireredasr/models/module/transformer_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index b2a376f..2b083c0 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -142,7 +142,7 @@ def expand(f, mask, indices, value=0.0, t=None): t_encoder_outputs = encoder_outputs[active_indices] t_src_mask = src_mask[active_indices] - if ATTENTION_BACKEND in {"FLASH_ATTN", "XFORMERS"}: + if ATTENTION_BACKEND in {"FLASH_ATTN"}: tgt_mask = raw_ys[active_indices] seq_lens = self.cal_seq_lens(t_src_mask) seq_lens_cpu = seq_lens.cpu()
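
A standalone sketch (not part of the patches above) of how the variable-length cross-attention metadata consumed by flash_attn_varlen_func in DecoderMHAFlashAttn can be derived from the encoder padding mask. The batch size, sequence lengths, and random key tensor below are illustrative assumptions; the real call additionally requires fp16/bf16 CUDA tensors.

import torch

bs, Ti, n_head, d_k = 3, 7, 2, 4
# src padding mask, 1 = valid encoder frame, 0 = padding, shape (bs, Ti)
mask = torch.tensor([[1] * 7, [1] * 5 + [0] * 2, [1] * 4 + [0] * 3])

seq_lens = mask.sum(dim=1, dtype=torch.int32)        # per-sample key lengths
# incremental decoding feeds one query token per active beam, so the query
# lengths are all 1 and cu_seqlens_q is a plain arange
cu_seqlens_q = torch.arange(0, bs + 1, dtype=torch.int32)
cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(seq_lens, dim=0)     # prefix sums of key lengths
max_seqlen_q, max_seqlen_k = 1, int(seq_lens.max())

# pack keys (and values) by dropping padded frames, so sample i occupies rows
# cu_seqlens_k[i]:cu_seqlens_k[i+1] of the packed tensor
k = torch.randn(bs, Ti, n_head, d_k)
k_packed = k[mask.bool()]                            # (seq_lens.sum(), n_head, d_k)

print(cu_seqlens_q.tolist())   # [0, 1, 2, 3]
print(cu_seqlens_k.tolist())   # [0, 7, 12, 16]
print(k_packed.shape)          # torch.Size([16, 2, 4])

Because these quantities depend only on the encoder mask and the currently active beams, the later patches cache the packed keys/values per layer and carry the lengths once per step in AttentionMeta instead of recomputing them inside every attention module.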
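A second standalone sketch, covering the attn_bias workaround in DecoderXFormersAttention.set_cross_attn_bias: xformers' memory_efficient_attention tends to reject an additive bias whose last-dimension stride is not a multiple of 8, so the bias is allocated padded up to a multiple of 8 and then sliced back to the real (q_len, k_len) shape, keeping the aligned backing storage. Shapes here are illustrative; the resulting attn_bias would be passed to xops.memory_efficient_attention together with fp16 q/k/v on GPU.

import torch

bs, n_head, q_len, k_len = 2, 4, 1, 13
# boolean padding mask, True = attend, False = masked, shape (bs, 1, q_len, k_len)
mask = torch.ones(bs, 1, q_len, k_len, dtype=torch.bool)
mask[1, 0, :, 10:] = False                       # sample 1 has 3 padded frames

mask = mask.expand(bs, n_head, q_len, k_len).reshape(bs * n_head, q_len, k_len)

pad_q = ((q_len + 7) // 8) * 8                   # 8
pad_k = ((k_len + 7) // 8) * 8                   # 16
bias_full = torch.zeros(bs * n_head, pad_q, pad_k, dtype=torch.float16)
bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf"))
attn_bias = bias_full[:, :q_len, :k_len]         # view keeps the padded storage

print(attn_bias.shape, attn_bias.stride())       # torch.Size([8, 1, 13]) (128, 16, 1)

The over-allocation is negligible next to the attention computation itself, and slicing (rather than copying) preserves the aligned strides the kernel checks for.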