diff --git a/diffsynth_engine/models/base.py b/diffsynth_engine/models/base.py
index 648134a2..0471b1bd 100644
--- a/diffsynth_engine/models/base.py
+++ b/diffsynth_engine/models/base.py
@@ -60,6 +60,9 @@ def get_tp_plan(self):
     def get_fsdp_module_cls(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")
 
+    def enable_fp8_linear(self):
+        raise NotImplementedError(f"{self.__class__.__name__} does not support FP8 linear")
+
 
 def split_suffix(name: str):
     suffix_list = [
diff --git a/diffsynth_engine/models/basic/linear.py b/diffsynth_engine/models/basic/linear.py
new file mode 100644
index 00000000..0982024f
--- /dev/null
+++ b/diffsynth_engine/models/basic/linear.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+from diffsynth_engine.utils.platform import DTYPE_FP8, FP8_MAX
+
+
+def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    x_max = x.abs().float().amax(dim=-1, keepdim=True).clamp(min=1e-4)
+    scale = x_max / FP8_MAX
+    x_scaled = x / scale
+    return x_scaled, scale
+
+
+def fp8_linear(
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None, scaling: bool = True
+) -> torch.Tensor:
+    device = input.device
+    origin_dtype = input.dtype
+    origin_shape = input.shape
+    input = input.reshape(-1, origin_shape[-1])
+    out_features, _ = weight.shape
+
+    if scaling:
+        input, scale_a = per_token_cast_to_fp8(input)
+        scale_b = torch.ones((out_features, 1), device=device)
+    else:
+        scale_a = torch.tensor(1.0, device=device)
+        scale_b = torch.tensor(1.0, device=device)
+    input = input.to(DTYPE_FP8)
+    weight = weight.to(DTYPE_FP8)
+
+    result = torch._scaled_mm(
+        input,
+        weight.T,
+        scale_a=scale_a,
+        scale_b=scale_b.T,
+        bias=bias,
+        out_dtype=origin_dtype,
+    )
+    new_shape = origin_shape[:-1] + result.shape[-1:]
+    result = result.reshape(new_shape)
+    return result
+
+
+class FP8Linear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+        scaling: bool = True,
+    ):
+        super().__init__(in_features, out_features, bias, device, dtype)
+        self.weight.data = self.weight.data.to(DTYPE_FP8)
+        self.scaling = scaling
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return fp8_linear(input, self.weight, self.bias, self.scaling)
diff --git a/diffsynth_engine/models/basic/lora.py b/diffsynth_engine/models/basic/lora.py
index e4065bd4..621b1cf7 100644
--- a/diffsynth_engine/models/basic/lora.py
+++ b/diffsynth_engine/models/basic/lora.py
@@ -5,6 +5,9 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 
+from diffsynth_engine.models.basic.linear import fp8_linear
+from diffsynth_engine.utils.platform import DTYPE_FP8
+
 
 class LoRA(nn.Module):
     def __init__(
@@ -164,6 +167,41 @@ def forward(self, x):
         return w_x
 
+
+class LoRAFP8Linear(LoRALinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+        scaling: bool = True,
+    ) -> None:
+        super().__init__(in_features, out_features, bias, device, dtype)
+        self.weight.data = self.weight.data.to(DTYPE_FP8)
+        self.scaling = scaling
+
+    @staticmethod
+    def from_linear(linear: nn.Linear, scaling: bool = True):
+        lora_linear = LoRAFP8Linear(
+            linear.in_features,
+            linear.out_features,
+            linear.bias is not None,
+            device="meta",
+            dtype=linear.weight.dtype,
+            scaling=scaling,
+        ).to_empty(device=linear.weight.device)
+        lora_linear.weight.data = linear.weight.data.to(DTYPE_FP8)
+        lora_linear.bias = linear.bias
+        return lora_linear
+
+    def forward(self, x):
+        w_x = fp8_linear(x, self.weight, self.bias, self.scaling)  # only use fp8 linear for base layer
+        for name, lora in self._lora_dict.items():
+            w_x += lora(x)
+        return w_x
+
 
 class LoRAConv2d(nn.Conv2d):
     def __init__(
         self,
diff --git a/diffsynth_engine/models/flux/flux_dit.py b/diffsynth_engine/models/flux/flux_dit.py
index 767991b8..05b1c2e8 100644
--- a/diffsynth_engine/models/flux/flux_dit.py
+++ b/diffsynth_engine/models/flux/flux_dit.py
@@ -11,11 +11,11 @@
     RoPEEmbedding,
     RMSNorm,
 )
+from diffsynth_engine.models.basic import attention as attention_ops
+from diffsynth_engine.models.basic.lora import LoRAFP8Linear
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
-from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
@@ -405,10 +405,8 @@ def forward(
             controlnet_single_block_output if controlnet_single_block_output is not None else ()
         )
 
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = hidden_states.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel(
                 (
@@ -506,3 +504,9 @@ def compile_repeated_blocks(self, *args, **kwargs):
 
     def get_fsdp_module_cls(self):
         return {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}
+
+    def enable_fp8_linear(self):
+        target_names = ["blocks", "single_blocks"]
+        for name, module in self.named_modules():
+            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
diff --git a/diffsynth_engine/models/flux/flux_dit_fbcache.py b/diffsynth_engine/models/flux/flux_dit_fbcache.py
index 15c41c50..20916d7e 100644
--- a/diffsynth_engine/models/flux/flux_dit_fbcache.py
+++ b/diffsynth_engine/models/flux/flux_dit_fbcache.py
@@ -3,7 +3,6 @@
 from typing import Any, Dict, List, Optional
 
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
     cfg_parallel_unshard,
@@ -68,10 +67,8 @@ def forward(
             controlnet_single_block_output if controlnet_single_block_output is not None else ()
         )
 
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = hidden_states.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel(
                 (
diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit.py b/diffsynth_engine/models/qwen_image/qwen_image_dit.py
index 01869e1b..62114963 100644
--- a/diffsynth_engine/models/qwen_image/qwen_image_dit.py
+++ b/diffsynth_engine/models/qwen_image/qwen_image_dit.py
@@ -7,8 +7,8 @@
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
+from diffsynth_engine.models.basic.lora import LoRAFP8Linear
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
     cfg_parallel_unshard,
@@ -441,10 +441,8 @@ def forward(
         attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = image.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel(
                 (
@@ -540,3 +538,9 @@ def compile_repeated_blocks(self, *args, **kwargs):
 
     def get_fsdp_module_cls(self):
         return {QwenImageTransformerBlock}
+
+    def enable_fp8_linear(self):
+        target_names = ["transformer_blocks"]
+        for name, module in self.named_modules():
+            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py b/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py
index 641168b5..4aab73f8 100644
--- a/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py
+++ b/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py
@@ -3,7 +3,6 @@
 
 from diffsynth_engine.models.qwen_image import QwenImageDiT
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import cfg_parallel, cfg_parallel_unshard
 
 
@@ -45,10 +44,8 @@ def forward(
         attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = image.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel(
                 (
diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py b/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py
index 3c580cea..64146157 100644
--- a/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py
+++ b/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py
@@ -339,3 +339,6 @@ def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
                 device=fuse_dict[key]["q"]["device"],
                 dtype=fuse_dict[key]["q"]["dtype"],
             )
+
+    def enable_fp8_linear(self):
+        raise NotImplementedError(f"{self.__class__.__name__} does not support FP8 linear")
diff --git a/diffsynth_engine/models/wan/wan_dit.py b/diffsynth_engine/models/wan/wan_dit.py
index 86dc9d68..02639b47 100644
--- a/diffsynth_engine/models/wan/wan_dit.py
+++ b/diffsynth_engine/models/wan/wan_dit.py
@@ -9,6 +9,7 @@
 from diffsynth_engine.models.basic.attention import attention
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
+from diffsynth_engine.models.basic.lora import LoRAFP8Linear
 from diffsynth_engine.utils.constants import (
     WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
     WAN2_1_DIT_I2V_14B_CONFIG_FILE,
@@ -20,7 +21,6 @@
     WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
     cfg_parallel_unshard,
@@ -386,10 +386,8 @@ def forward(
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
         attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = x.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
         ):
@@ -541,3 +539,9 @@ def compile_repeated_blocks(self, *args, **kwargs):
 
     def get_fsdp_module_cls(self):
         return {DiTBlock}
+
+    def enable_fp8_linear(self):
+        target_names = ["blocks"]
+        for name, module in self.named_modules():
+            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
diff --git a/diffsynth_engine/models/wan/wan_s2v_dit.py b/diffsynth_engine/models/wan/wan_s2v_dit.py
index d0d21c5a..fe0f2b82 100644
--- a/diffsynth_engine/models/wan/wan_s2v_dit.py
+++ b/diffsynth_engine/models/wan/wan_s2v_dit.py
@@ -18,7 +18,6 @@
 )
 from diffsynth_engine.utils.constants import WAN2_2_DIT_S2V_14B_CONFIG_FILE
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
     cfg_parallel_unshard,
@@ -412,10 +411,8 @@ def forward(
         audio_mask: Optional[torch.Tensor] = None,  # b c tx h w
         void_audio_input: Optional[torch.Tensor] = None,
     ):
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel((x, context, audio_input), use_cfg=use_cfg),
         ):
diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py
index a836efc5..7261c682 100644
--- a/diffsynth_engine/pipelines/base.py
+++ b/diffsynth_engine/pipelines/base.py
@@ -15,7 +15,7 @@
 )
 from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
-from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
+from diffsynth_engine.utils.autocast import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.loader import load_file
diff --git a/diffsynth_engine/pipelines/flux_image.py b/diffsynth_engine/pipelines/flux_image.py
index e31a366a..22708f84 100644
--- a/diffsynth_engine/pipelines/flux_image.py
+++ b/diffsynth_engine/pipelines/flux_image.py
@@ -32,7 +32,6 @@
 from diffsynth_engine.utils.constants import FLUX_TOKENIZER_1_CONF_PATH, FLUX_TOKENIZER_2_CONF_PATH
 from diffsynth_engine.utils.parallel import ParallelWrapper
 from diffsynth_engine.utils import logging
-from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
 from diffsynth_engine.utils.download import fetch_model
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
@@ -528,7 +527,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
             in_channel=config.control_type.get_in_channel(),
         )
         if config.use_fp8_linear:
-            enable_fp8_linear(dit)
+            dit.enable_fp8_linear()
 
         pipe = cls(
             config=config,
diff --git a/diffsynth_engine/pipelines/qwen_image.py b/diffsynth_engine/pipelines/qwen_image.py
index 5cd90a3b..a608bded 100644
--- a/diffsynth_engine/pipelines/qwen_image.py
+++ b/diffsynth_engine/pipelines/qwen_image.py
@@ -2,7 +2,6 @@
 import torch
 import torch.distributed as dist
 import math
-import sys
 from typing import Callable, List, Dict, Tuple, Optional, Union
 from tqdm import tqdm
 from einops import rearrange
@@ -37,7 +36,6 @@
 )
 from diffsynth_engine.utils.parallel import ParallelWrapper
 from diffsynth_engine.utils import logging
-from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
 from diffsynth_engine.utils.download import fetch_model
 from diffsynth_engine.utils.flag import NUNCHAKU_AVAILABLE
 
@@ -45,7 +43,6 @@
 logger = logging.get_logger(__name__)
 
 
-
 class QwenImageLoRAConverter(LoRAStateDictConverter):
     def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
         dit_dict = {}
@@ -205,7 +202,7 @@ def _setup_nunchaku_config(
             else:
                 config.use_nunchaku_attn = False
                 logger.info("Disable nunchaku attention quantization.")
-            
+
         else:
             config.use_nunchaku = False
 
@@ -318,6 +315,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
         elif config.use_nunchaku:
             if not NUNCHAKU_AVAILABLE:
                 from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR
+
                 raise ImportError(NUNCHAKU_IMPORT_ERROR)
 
             from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
@@ -339,7 +337,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
                 dtype=config.model_dtype,
             )
         if config.use_fp8_linear and not config.use_nunchaku:
-            enable_fp8_linear(dit)
+            dit.enable_fp8_linear()
 
         pipe = cls(
             config=config,
diff --git a/diffsynth_engine/pipelines/wan_s2v.py b/diffsynth_engine/pipelines/wan_s2v.py
index 3f3f3a4a..836c2ecd 100644
--- a/diffsynth_engine/pipelines/wan_s2v.py
+++ b/diffsynth_engine/pipelines/wan_s2v.py
@@ -24,7 +24,6 @@
 from diffsynth_engine.tokenizers import WanT5Tokenizer
 from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
 from diffsynth_engine.utils.download import fetch_model
-from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
 from diffsynth_engine.utils.image import resize_and_center_crop
 from diffsynth_engine.utils.video import read_n_frames
 from diffsynth_engine.utils.parallel import ParallelWrapper
@@ -666,7 +665,7 @@ def _from_state_dict(
             use_vsa=(config.dit_attn_impl.value == "vsa"),
         )
         if config.use_fp8_linear:
-            enable_fp8_linear(dit)
+            dit.enable_fp8_linear()
 
         pipe = cls(
             config=config,
diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py
index 116effb8..e485ef13 100644
--- a/diffsynth_engine/pipelines/wan_video.py
+++ b/diffsynth_engine/pipelines/wan_video.py
@@ -16,7 +16,6 @@
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
 from diffsynth_engine.utils.download import fetch_model
-from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
 from diffsynth_engine.utils.parallel import ParallelWrapper
 from diffsynth_engine.utils import logging
 
@@ -594,7 +593,7 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig)
             use_vsa=(config.dit_attn_impl.value == "vsa"),
         )
         if config.use_fp8_linear:
-            enable_fp8_linear(dit)
+            dit.enable_fp8_linear()
 
         dit2 = None
         if dit2_state_dict is not None:
@@ -606,7 +605,7 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig)
             use_vsa=(config.dit_attn_impl.value == "vsa"),
         )
         if config.use_fp8_linear:
-            enable_fp8_linear(dit2)
+            dit2.enable_fp8_linear()
 
         pipe = cls(
             config=config,
@@ -681,8 +680,9 @@ def has_any_key(*xs):
             config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
 
     def update_weights(self, state_dicts: WanStateDicts) -> None:
-        is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
-                                    ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+        is_dual_model_state_dict = isinstance(state_dicts.model, dict) and (
+            "high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model
+        )
         is_dual_model_pipeline = self.dit2 is not None
 
         if is_dual_model_state_dict != is_dual_model_pipeline:
@@ -694,15 +694,21 @@ def update_weights(self, state_dicts: WanStateDicts) -> None:
 
         if is_dual_model_state_dict:
             if "high_noise_model" in state_dicts.model:
-                self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype
+                )
             if "low_noise_model" in state_dicts.model:
-                self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype
+                )
         else:
             self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
 
         self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
         self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
-        self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+        self.update_component(
+            self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype
+        )
 
     def compile(self):
         self.dit.compile_repeated_blocks()
diff --git a/diffsynth_engine/utils/autocast.py b/diffsynth_engine/utils/autocast.py
new file mode 100644
index 00000000..64d55ceb
--- /dev/null
+++ b/diffsynth_engine/utils/autocast.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+
+
+def enable_fp8_autocast(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False):
+    if len(list(module.children())) == 0:
+        if len(list(module.parameters())) > 0:
+            add_fp8_autocast_hook(module, compute_dtype)
+        return
+    if len(list(module.parameters(recurse=False))) > 0:
+        add_fp8_autocast_hook(module, compute_dtype)
+    for submodule in module.children():
+        if isinstance(submodule, nn.Linear) and use_fp8_linear:
+            continue
+
+        enable_fp8_autocast(submodule, compute_dtype, use_fp8_linear)
+
+
+def add_fp8_autocast_hook(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16):
+    def _fp8_autocast_pre_hook(module: nn.Module, input_):
+        for name, param in module.named_parameters():
+            if param.dtype == torch.float8_e4m3fn:
+                param.data = param.data.to(compute_dtype)
+        new_inputs = []
+        for x in input_:
+            if isinstance(x, torch.Tensor) and x.dtype in [torch.float8_e4m3fn, torch.float16, torch.bfloat16]:
+                new_inputs.append(x.to(compute_dtype))
+            else:
+                new_inputs.append(x)
+        return tuple(new_inputs)
+
+    def _fp8_autocast_hook(module: nn.Module, input_, output_):
+        for name, param in module.named_parameters():
+            if param.dtype == compute_dtype:
+                param.data = param.data.to(torch.float8_e4m3fn)
+
+    if getattr(module, "_fp8_autocast_enabled", False):
+        return
+    module.register_forward_pre_hook(_fp8_autocast_pre_hook)
+    module.register_forward_hook(_fp8_autocast_hook)
+    setattr(module, "_fp8_autocast_enabled", True)
diff --git a/diffsynth_engine/utils/fp8_linear.py b/diffsynth_engine/utils/fp8_linear.py
deleted file mode 100644
index b4c2ea73..00000000
--- a/diffsynth_engine/utils/fp8_linear.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from contextlib import contextmanager
-from diffsynth_engine.utils.platform import DTYPE_FP8
-
-
-def enable_fp8_autocast(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False):
-    if len(list(module.children())) == 0:
-        if len(list(module.parameters())) > 0:
-            add_fp8_autocast_hook(module, compute_dtype)
-        return
-    if len(list(module.parameters(recurse=False))) > 0:
-        add_fp8_autocast_hook(module, compute_dtype)
-    for submodule in module.children():
-        if isinstance(submodule, nn.Linear) and use_fp8_linear:
-            continue
-
-        enable_fp8_autocast(submodule, compute_dtype, use_fp8_linear)
-
-
-def add_fp8_autocast_hook(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16):
-    def _fp8_autocast_pre_hook(module: nn.Module, input_):
-        for name, param in module.named_parameters():
-            if param.dtype == torch.float8_e4m3fn:
-                param.data = param.data.to(compute_dtype)
-        new_inputs = []
-        for x in input_:
-            if isinstance(x, torch.Tensor) and x.dtype in [torch.float8_e4m3fn, torch.float16, torch.bfloat16]:
-                new_inputs.append(x.to(compute_dtype))
-            else:
-                new_inputs.append(x)
-        return tuple(new_inputs)
-
-    def _fp8_autocast_hook(module: nn.Module, input_, output_):
-        for name, param in module.named_parameters():
-            if param.dtype == compute_dtype:
-                param.data = param.data.to(torch.float8_e4m3fn)
-
-    if getattr(module, "_fp8_autocast_enabled", False):
-        return
-    module.register_forward_pre_hook(_fp8_autocast_pre_hook)
-    module.register_forward_hook(_fp8_autocast_hook)
-    setattr(module, "_fp8_autocast_enabled", True)
-
-
-def enable_fp8_linear(module: nn.Module):
-    _enable_fp8_linear(module)
-    setattr(module, "fp8_linear_enabled", True)
-
-
-def _enable_fp8_linear(module: nn.Module):
-    if isinstance(module, nn.Linear) and torch.is_floating_point(module.weight.data):
-        # avoid conversion for int weights like GGUF
-        module.weight.data = module.weight.data.to(DTYPE_FP8)
-    for submodule in module.children():
-        _enable_fp8_linear(submodule)
-
-
-@contextmanager
-def fp8_inference(enabled=True):
-    if not enabled:
-        yield
-        return
-
-    origin_linear = F.linear
-
-    def fp8_linear(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        bias: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        device = input.device
-        origin_dtype = input.dtype
-        origin_shape = input.shape
-        input = input.reshape(-1, origin_shape[-1])
-
-        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
-        fp8_max = 448.0
-        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
-        # To avoid overflow and ensure numerical compatibility during FP8 computation,
-        # we scale down the input by 2.0 in advance.
-        # This scaling will be compensated later during the final result scaling.
-        if DTYPE_FP8 == torch.float8_e4m3fnuz:
-            fp8_max = fp8_max / 2.0
-        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
-        scale_b = torch.ones((weight.shape[0], 1)).float().to(device=device)
-        input = input / scale_a
-        input = input.to(DTYPE_FP8)
-        weight = weight.to(DTYPE_FP8)
-
-        result = torch._scaled_mm(
-            input,
-            weight.T,
-            scale_a=scale_a,
-            scale_b=scale_b.T,
-            bias=bias,
-            out_dtype=origin_dtype,
-        )
-        new_shape = origin_shape[:-1] + result.shape[-1:]
-        result = result.reshape(new_shape)
-        return result
-
-    F.linear = fp8_linear
-    yield
-    F.linear = origin_linear
diff --git a/diffsynth_engine/utils/platform.py b/diffsynth_engine/utils/platform.py
index 49a69680..ee50d7ca 100644
--- a/diffsynth_engine/utils/platform.py
+++ b/diffsynth_engine/utils/platform.py
@@ -12,6 +12,8 @@
 else:
     DTYPE_FP8 = torch.float8_e4m3fn
 
+FP8_MAX = torch.finfo(DTYPE_FP8).max
+
 
 def empty_cache():
     gc.collect()
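Illustrative sketch (not part of the patch): every new `enable_fp8_linear()` method in this diff follows the same pattern, i.e. walk `named_modules()`, and for each `nn.Linear` whose name contains one of the targeted block prefixes, swap in a LoRA-aware FP8 replacement via `set_submodule`. The toy module below reproduces that pattern so it runs on CPU without `torch._scaled_mm`; `ToyDiT` and `ToyFP8Linear` are hypothetical stand-ins for the real DiT classes and `LoRAFP8Linear`, and the sketch assumes a PyTorch build that provides the `float8_e4m3fn` dtype.

```python
import torch
import torch.nn as nn

DTYPE_FP8 = torch.float8_e4m3fn  # assumed value of diffsynth_engine.utils.platform.DTYPE_FP8 on CUDA


class ToyFP8Linear(nn.Linear):
    """Stand-in for LoRAFP8Linear: stores FP8 weights, upcasts for the matmul so it runs on CPU."""

    @staticmethod
    def from_linear(linear: nn.Linear) -> "ToyFP8Linear":
        new = ToyFP8Linear(linear.in_features, linear.out_features, linear.bias is not None)
        new.weight.data = linear.weight.data.to(DTYPE_FP8)  # FP8 storage, mirroring from_linear in the diff
        if linear.bias is not None:
            new.bias.data = linear.bias.data.clone()
        return new

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real LoRAFP8Linear routes through torch._scaled_mm with per-token scales instead of upcasting.
        return nn.functional.linear(x, self.weight.to(x.dtype), self.bias)


class ToyDiT(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(2)])
        self.proj_out = nn.Linear(64, 16)  # not in target_names, stays in full precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return self.proj_out(x)

    def enable_fp8_linear(self):
        # Same shape as the FluxDiT / QwenImageDiT / WanDiT implementations in this diff.
        target_names = ["blocks"]
        for name, module in self.named_modules():
            if any(t in name for t in target_names) and isinstance(module, nn.Linear):
                self.set_submodule(name, ToyFP8Linear.from_linear(module))


model = ToyDiT()
model.enable_fp8_linear()
print(model.blocks[0].weight.dtype)     # torch.float8_e4m3fn
print(model.proj_out.weight.dtype)      # torch.float32
print(model(torch.randn(2, 64)).shape)  # torch.Size([2, 16])
```

Compared with the removed `fp8_inference` context manager, which monkey-patched `F.linear` globally for the duration of a forward pass, swapping submodules up front keeps FP8 scoped to the listed transformer blocks and, per the comment in `LoRAFP8Linear.forward`, leaves the LoRA branches in the activation dtype.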