     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
 )
+from diffsynth_engine.utils.platform import DTYPE_FP8
 
 FA3_MAX_HEADDIM = 256
 
@@ -125,12 +126,13 @@ def attention(
         None,
         "auto",
         "eager",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "xformers",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -139,9 +141,13 @@ def attention(
             return flash_attn3(q, k, v, softmax_scale=scale)
         else:
             if not flash_attn3_compatible:
-                logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
             else:
-                logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
+                logger.debug(
+                    "flash_attn_3 does not support attention mask, will use fallback attention implementation"
+                )
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -152,23 +158,31 @@ def attention(
     else:
         if attn_impl == "eager":
             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "flash_attn_3":
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
             if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
             if attn_mask is not None:
                 raise RuntimeError("flash_attn_3 does not support attention mask")
-            return flash_attn3(q, k, v, softmax_scale=scale)
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return flash_attn3(q, k, v, softmax_scale=scale)
+            else:
+                origin_dtype = q.dtype
+                q = q.to(dtype=DTYPE_FP8)
+                k = k.to(dtype=DTYPE_FP8)
+                v = v.to(dtype=DTYPE_FP8)
+                out = flash_attn3(q, k, v, softmax_scale=scale)
+                return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return flash_attn2(q, k, v, softmax_scale=scale)
         if attn_impl == "xformers":
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if attn_impl == "sdpa":
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sage_attn":
+        if attn_impl == "sage":
             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sparge":
             return sparge_attn(
                 q,
                 k,
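
Note on the new "fa3_fp8" path in the hunk above: it keeps the FlashAttention-3 call but round-trips q, k, and v through the platform FP8 dtype and casts the output back to the caller's dtype. Below is a minimal standalone sketch of that wrapper pattern; it assumes torch.float8_e4m3fn as the FP8 dtype and takes the attention kernel as a plain callable, whereas the diff resolves DTYPE_FP8 from diffsynth_engine.utils.platform and calls flash_attn3 directly.

import torch

DTYPE_FP8 = torch.float8_e4m3fn  # assumption; the library resolves this per platform


def fp8_roundtrip_attention(q, k, v, attn_fn, softmax_scale=None):
    # Remember the incoming precision so callers see no dtype change.
    origin_dtype = q.dtype
    # Down-cast inputs to FP8 before invoking the low-precision kernel.
    q, k, v = (t.to(dtype=DTYPE_FP8) for t in (q, k, v))
    out = attn_fn(q, k, v, softmax_scale=softmax_scale)
    # Restore the original dtype on the way out.
    return out.to(dtype=origin_dtype)
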
@@ -247,12 +261,14 @@ def long_context_attention(
     assert attn_impl in [
         None,
         "auto",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
     ]
+    assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
@@ -268,20 +284,27 @@ def long_context_attention(
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         raise ValueError("No available long context attention implementation")
     else:
-        if attn_impl == "flash_attn_3":
-            if flash_attn3_compatible:
-                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
-            else:
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
+            if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         if attn_impl == "sdpa":
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sage_attn":
-            return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sage":
+            return LongContextAttention(attn_type=AttnType.SAGE_AUTO)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sparge":
             attn_processor = SparseAttentionMeansim()
             # default args from spas_sage2_attn_meansim_cuda
             attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
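
For reference, a hypothetical call-site sketch using the renamed attn_impl values; the import path and tensor layout are assumptions for illustration, not part of this diff.

import torch
from diffsynth_engine.models.basic.attention import attention  # assumed import path

# Illustrative (batch, seq_len, num_heads, head_dim) layout; adjust to your model.
q = torch.randn(1, 1024, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 1024, 8, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 1024, 8, 64, dtype=torch.bfloat16, device="cuda")

out = attention(q, k, v, attn_impl="fa3")      # was "flash_attn_3"
out = attention(q, k, v, attn_impl="fa3_fp8")  # new FP8 round-trip path
out = attention(q, k, v, attn_impl="sage")     # was "sage_attn"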