3 changes: 3 additions & 0 deletions diffsynth_engine/models/base.py
@@ -60,6 +60,9 @@ def get_tp_plan(self):
    def get_fsdp_module_cls(self):
        raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")

    def enable_fp8_linear(self):
        raise NotImplementedError(f"{self.__class__.__name__} does not support FP8 linear")


def split_suffix(name: str):
    suffix_list = [
61 changes: 61 additions & 0 deletions diffsynth_engine/models/basic/linear.py
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
from typing import Tuple

from diffsynth_engine.utils.platform import DTYPE_FP8, FP8_MAX


def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    x_max = x.abs().float().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = x_max / FP8_MAX
    x_scaled = x / scale
    return x_scaled, scale


def fp8_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None, scaling: bool = True
) -> torch.Tensor:
    device = input.device
    origin_dtype = input.dtype
    origin_shape = input.shape
    input = input.reshape(-1, origin_shape[-1])
    out_features, _ = weight.shape

    if scaling:
        input, scale_a = per_token_cast_to_fp8(input)
        scale_b = torch.ones((out_features, 1), device=device)
    else:
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
    input = input.to(DTYPE_FP8)
    weight = weight.to(DTYPE_FP8)

    result = torch._scaled_mm(
        input,
        weight.T,
        scale_a=scale_a,
        scale_b=scale_b.T,
        bias=bias,
        out_dtype=origin_dtype,
    )
    new_shape = origin_shape[:-1] + result.shape[-1:]
    result = result.reshape(new_shape)
    return result


class FP8Linear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        scaling: bool = True,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.weight.data = self.weight.data.to(DTYPE_FP8)
        self.scaling = scaling

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return fp8_linear(input, self.weight, self.bias, self.scaling)
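
For orientation, here is a minimal usage sketch of the new module (not part of the diff); it assumes an FP8-capable GPU (e.g., Hopper), and the shapes below are illustrative:

import torch
from diffsynth_engine.models.basic.linear import FP8Linear, fp8_linear

device = "cuda"
x = torch.randn(2, 16, 3072, dtype=torch.bfloat16, device=device)

# Module form: the weight is stored in FP8; the input is scaled per token at call time.
layer = FP8Linear(3072, 3072, bias=True, device=device, dtype=torch.bfloat16)
y = layer(x)  # result comes back in the original dtype (bfloat16 here)

# Functional form: reuse an existing bf16 weight without building a module.
w = torch.randn(3072, 3072, dtype=torch.bfloat16, device=device)
y2 = fp8_linear(x, w, bias=None, scaling=True)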
38 changes: 38 additions & 0 deletions diffsynth_engine/models/basic/lora.py
@@ -5,6 +5,9 @@
from collections import OrderedDict
from contextlib import contextmanager

from diffsynth_engine.models.basic.linear import fp8_linear
from diffsynth_engine.utils.platform import DTYPE_FP8


class LoRA(nn.Module):
    def __init__(
@@ -164,6 +167,41 @@ def forward(self, x):
        return w_x


class LoRAFP8Linear(LoRALinear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        scaling: bool = True,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.weight.data = self.weight.data.to(DTYPE_FP8)
        self.scaling = scaling

    @staticmethod
    def from_linear(linear: nn.Linear, scaling: bool = True):
        lora_linear = LoRAFP8Linear(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device="meta",
            dtype=linear.weight.dtype,
            scaling=scaling,
        ).to_empty(device=linear.weight.device)
        lora_linear.weight.data = linear.weight.data.to(DTYPE_FP8)
        lora_linear.bias = linear.bias
        return lora_linear

    def forward(self, x):
        w_x = fp8_linear(x, self.weight, self.bias, self.scaling)  # only use fp8 linear for base layer
        for name, lora in self._lora_dict.items():
            w_x += lora(x)
        return w_x


class LoRAConv2d(nn.Conv2d):
    def __init__(
        self,
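
For context, a small sketch (not part of the diff) of how LoRAFP8Linear.from_linear can retrofit an existing layer; the layer size and device are illustrative assumptions:

import torch
import torch.nn as nn
from diffsynth_engine.models.basic.lora import LoRAFP8Linear

# Wrap an existing bf16 projection so the base matmul runs through fp8_linear,
# while any LoRA branches later attached to the layer stay in the original precision.
linear = nn.Linear(3072, 3072, bias=True, device="cuda", dtype=torch.bfloat16)
fp8_layer = LoRAFP8Linear.from_linear(linear)  # weight re-stored as FP8, bias shared

x = torch.randn(1, 16, 3072, device="cuda", dtype=torch.bfloat16)
y = fp8_layer(x)  # FP8 base path; no LoRA attached yet
assert y.dtype == torch.bfloat16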
12 changes: 8 additions & 4 deletions diffsynth_engine/models/flux/flux_dit.py
@@ -11,11 +11,11 @@
    RoPEEmbedding,
    RMSNorm,
)
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.models.basic.lora import LoRAFP8Linear
from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
@@ -405,10 +405,8 @@ def forward(
            controlnet_single_block_output if controlnet_single_block_output is not None else ()
        )

        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = hidden_states.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel(
                (
@@ -506,3 +504,9 @@ def compile_repeated_blocks(self, *args, **kwargs):

    def get_fsdp_module_cls(self):
        return {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}

    def enable_fp8_linear(self):
        target_names = ["blocks", "single_blocks"]
        for name, module in self.named_modules():
            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
Contributor review comment (medium):
The condition any([t in name for t in target_names]) used to identify target modules is a bit loose and could lead to incorrect module replacement if a module name incidentally contains one of the target names (e.g., a new module named archived_blocks). Using startswith would be more robust and ensure that only modules within the specified parent modules are matched.

Suggested change:
-            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
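
To make the reviewer's concern concrete, a quick illustration (the module name below is hypothetical):

# A submodule name that merely contains "blocks" but does not live under blocks/single_blocks:
name = "archived_blocks.0.proj"
target_names = ["blocks", "single_blocks"]
any(t in name for t in target_names)                 # True  -> the loose check would replace it
any(name.startswith(f"{t}.") for t in target_names)  # False -> the prefix check skips it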
3 changes: 0 additions & 3 deletions diffsynth_engine/models/flux/flux_dit_fbcache.py
@@ -3,7 +3,6 @@
from typing import Any, Dict, List, Optional

from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
    cfg_parallel_unshard,
@@ -68,10 +67,8 @@ def forward(
            controlnet_single_block_output if controlnet_single_block_output is not None else ()
        )

        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = hidden_states.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel(
                (
10 changes: 7 additions & 3 deletions diffsynth_engine/models/qwen_image/qwen_image_dit.py
@@ -7,8 +7,8 @@
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
from diffsynth_engine.models.basic.lora import LoRAFP8Linear
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
    cfg_parallel_unshard,
@@ -441,10 +441,8 @@ def forward(
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        h, w = image.shape[-2:]
        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = image.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel(
                (
@@ -540,3 +538,9 @@ def compile_repeated_blocks(self, *args, **kwargs):

    def get_fsdp_module_cls(self):
        return {QwenImageTransformerBlock}

    def enable_fp8_linear(self):
        target_names = ["transformer_blocks"]
        for name, module in self.named_modules():
            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
Contributor review comment (medium):
The condition any([t in name for t in target_names]) used to identify target modules is a bit loose and could lead to incorrect module replacement if a module name incidentally contains one of the target names. Using startswith would be more robust and ensure that only modules within the specified parent modules are matched.

Suggested change:
-            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
3 changes: 0 additions & 3 deletions diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py
@@ -3,7 +3,6 @@

from diffsynth_engine.models.qwen_image import QwenImageDiT
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.parallel import cfg_parallel, cfg_parallel_unshard


@@ -45,10 +44,8 @@ def forward(
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        h, w = image.shape[-2:]
        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = image.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel(
                (
3 changes: 3 additions & 0 deletions diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py
@@ -339,3 +339,6 @@ def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
device=fuse_dict[key]["q"]["device"],
dtype=fuse_dict[key]["q"]["dtype"],
)

    def enable_fp8_linear(self):
        raise NotImplementedError(f"{self.__class__.__name__} does not support FP8 linear")
10 changes: 7 additions & 3 deletions diffsynth_engine/models/wan/wan_dit.py
@@ -9,6 +9,7 @@
from diffsynth_engine.models.basic.attention import attention
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
from diffsynth_engine.models.basic.lora import LoRAFP8Linear
from diffsynth_engine.utils.constants import (
    WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
    WAN2_1_DIT_I2V_14B_CONFIG_FILE,
@@ -20,7 +21,6 @@
    WAN_DIT_KEYMAP_FILE,
)
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
    cfg_parallel_unshard,
@@ -386,10 +386,8 @@ def forward(
        y: Optional[torch.Tensor] = None,  # vae_encoder(img)
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = x.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
        ):
@@ -541,3 +539,9 @@ def compile_repeated_blocks(self, *args, **kwargs):

    def get_fsdp_module_cls(self):
        return {DiTBlock}

    def enable_fp8_linear(self):
        target_names = ["blocks"]
        for name, module in self.named_modules():
            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
Contributor review comment (medium):
The condition any([t in name for t in target_names]) used to identify target modules is a bit loose and could lead to incorrect module replacement if a module name incidentally contains one of the target names. Using startswith would be more robust and ensure that only modules within the specified parent modules are matched.

Suggested change:
-            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
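
As a quick sanity check after calling enable_fp8_linear() (a sketch only; it assumes dit is an already constructed WanDiT, or any of the other DiTs patched in this PR):

from diffsynth_engine.models.basic.lora import LoRAFP8Linear

dit.enable_fp8_linear()
n_fp8 = sum(isinstance(m, LoRAFP8Linear) for m in dit.modules())
print(f"replaced {n_fp8} nn.Linear layers inside the transformer blocks")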
3 changes: 0 additions & 3 deletions diffsynth_engine/models/wan/wan_s2v_dit.py
@@ -18,7 +18,6 @@
)
from diffsynth_engine.utils.constants import WAN2_2_DIT_S2V_14B_CONFIG_FILE
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
    cfg_parallel_unshard,
@@ -412,10 +411,8 @@ def forward(
        audio_mask: Optional[torch.Tensor] = None,  # b c tx h w
        void_audio_input: Optional[torch.Tensor] = None,
    ):
        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = x.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel((x, context, audio_input), use_cfg=use_cfg),
        ):
2 changes: 1 addition & 1 deletion diffsynth_engine/pipelines/base.py
@@ -15,7 +15,7 @@
)
from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
from diffsynth_engine.utils.autocast import enable_fp8_autocast
from diffsynth_engine.utils.gguf import load_gguf_checkpoint
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.loader import load_file
3 changes: 1 addition & 2 deletions diffsynth_engine/pipelines/flux_image.py
@@ -32,7 +32,6 @@
from diffsynth_engine.utils.constants import FLUX_TOKENIZER_1_CONF_PATH, FLUX_TOKENIZER_2_CONF_PATH
from diffsynth_engine.utils.parallel import ParallelWrapper
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.platform import empty_cache
from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
@@ -528,7 +527,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
            in_channel=config.control_type.get_in_channel(),
        )
        if config.use_fp8_linear:
            enable_fp8_linear(dit)
            dit.enable_fp8_linear()

        pipe = cls(
            config=config,
8 changes: 3 additions & 5 deletions diffsynth_engine/pipelines/qwen_image.py
@@ -2,7 +2,6 @@
import torch
import torch.distributed as dist
import math
import sys
from typing import Callable, List, Dict, Tuple, Optional, Union
from tqdm import tqdm
from einops import rearrange
@@ -37,15 +36,13 @@
)
from diffsynth_engine.utils.parallel import ParallelWrapper
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.flag import NUNCHAKU_AVAILABLE


logger = logging.get_logger(__name__)



class QwenImageLoRAConverter(LoRAStateDictConverter):
    def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        dit_dict = {}
@@ -205,7 +202,7 @@ def _setup_nunchaku_config(
else:
config.use_nunchaku_attn = False
logger.info("Disable nunchaku attention quantization.")

else:
config.use_nunchaku = False

@@ -318,6 +315,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
        elif config.use_nunchaku:
            if not NUNCHAKU_AVAILABLE:
                from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR

                raise ImportError(NUNCHAKU_IMPORT_ERROR)

            from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
@@ -339,7 +337,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
            dtype=config.model_dtype,
        )
        if config.use_fp8_linear and not config.use_nunchaku:
            enable_fp8_linear(dit)
            dit.enable_fp8_linear()

        pipe = cls(
            config=config,
3 changes: 1 addition & 2 deletions diffsynth_engine/pipelines/wan_s2v.py
@@ -24,7 +24,6 @@
from diffsynth_engine.tokenizers import WanT5Tokenizer
from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
from diffsynth_engine.utils.image import resize_and_center_crop
from diffsynth_engine.utils.video import read_n_frames
from diffsynth_engine.utils.parallel import ParallelWrapper
@@ -666,7 +665,7 @@ def _from_state_dict(
            use_vsa=(config.dit_attn_impl.value == "vsa"),
        )
        if config.use_fp8_linear:
            enable_fp8_linear(dit)
            dit.enable_fp8_linear()
Contributor review comment (critical):

This call to dit.enable_fp8_linear() will raise an AttributeError at runtime. The dit object, an instance of WanS2VDiT, does not have the enable_fp8_linear method defined.

To fix this, you should implement the enable_fp8_linear method in the WanS2VDiT class in diffsynth_engine/models/wan/wan_s2v_dit.py. It seems this was overlooked during the refactoring.

You can add a method similar to the one in WanDiT:

from diffsynth_engine.models.basic.lora import LoRAFP8Linear

...

class WanS2VDiT(WanDiT):
    ...
    def enable_fp8_linear(self):
        target_names = ["blocks"]
        for name, module in self.named_modules():
            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
                self.set_submodule(name, LoRAFP8Linear.from_linear(module))

Note that you'll also need to import LoRAFP8Linear in diffsynth_engine/models/wan/wan_s2v_dit.py.


        pipe = cls(
            config=config,