Skip to content

Commit 7a6d86d

Browse files
authored
speedup model cpu offload (#136)
* speedup model cpu offload * fix
1 parent e259c49 commit 7a6d86d

File tree

4 files changed

+40
-14
lines changed

4 files changed

+40
-14
lines changed

diffsynth_engine/models/qwen_image/qwen_image_dit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def forward(
315315

316316
class QwenImageDiT(PreTrainedModel):
317317
converter = QwenImageDiTStateDictConverter()
318+
_supports_parallelization = True
318319

319320
def __init__(
320321
self,
@@ -423,3 +424,6 @@ def from_state_dict(
423424
model.load_state_dict(state_dict, assign=True)
424425
model.to(device=device, dtype=dtype, non_blocking=True)
425426
return model
427+
428+
def get_fsdp_modules(self):
    """Return the names of submodule containers to be wrapped by FSDP.

    The transformer block list is the only component sharded for this model.
    """
    fsdp_module_names = ["transformer_blocks"]
    return fsdp_module_names

diffsynth_engine/pipelines/base.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from PIL import Image
77

88
from diffsynth_engine.configs import BaseConfig, BaseStateDicts
9-
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
9+
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
1010
from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
1111
from diffsynth_engine.utils.gguf import load_gguf_checkpoint
1212
from diffsynth_engine.utils import logging
@@ -40,6 +40,7 @@ def __init__(
4040
self.dtype = dtype
4141
self.offload_mode = None
4242
self.model_names = []
43+
self._offload_param_dict = {}
4344

4445
@classmethod
4546
def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
@@ -243,14 +244,13 @@ def _enable_model_cpu_offload(self):
243244
for model_name in self.model_names:
244245
model = getattr(self, model_name)
245246
if model is not None:
246-
model.to("cpu")
247+
self._offload_param_dict[model_name] = offload_model_to_dict(model)
247248
self.offload_mode = "cpu_offload"
248249

249250
def _enable_sequential_cpu_offload(self):
250251
for model_name in self.model_names:
251252
model = getattr(self, model_name)
252253
if model is not None:
253-
model.to("cpu")
254254
enable_sequential_cpu_offload(model, self.device)
255255
self.offload_mode = "sequential_cpu_offload"
256256

@@ -277,20 +277,12 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
277277
for model_name in self.model_names:
278278
if model_name not in load_model_names:
279279
model = getattr(self, model_name)
280-
if (
281-
model is not None
282-
and (p := next(model.parameters(), None)) is not None
283-
and p.device != torch.device("cpu")
284-
):
285-
model.to("cpu")
280+
if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != "cpu":
281+
restore_model_from_dict(model, self._offload_param_dict[model_name])
286282
# load the needed models to device
287283
for model_name in load_model_names:
288284
model = getattr(self, model_name)
289-
if (
290-
model is not None
291-
and (p := next(model.parameters(), None)) is not None
292-
and p.device != torch.device(self.device)
293-
):
285+
if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
294286
model.to(self.device)
295287
# free the CUDA cache
296288
empty_cache()

diffsynth_engine/pipelines/wan_video.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,11 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
584584
use_fsdp=config.use_fsdp,
585585
device="cuda",
586586
)
587+
if config.use_torch_compile:
588+
pipe.compile()
587589
return pipe
590+
591+
def compile(self):
    """Compile the DiT module(s) for faster inference.

    The primary DiT is always compiled; the optional second DiT is
    compiled only when present.
    """
    self.dit.compile()
    secondary_dit = self.dit2
    if secondary_dit is not None:
        secondary_dit.compile()

diffsynth_engine/utils/offload.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import torch
22
import torch.nn as nn
3+
from typing import Dict
34

45

56
def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda"):
7+
module = module.to("cpu")
68
if len(list(module.children())) == 0:
79
if len(list(module.parameters())) > 0 or len(list(module.buffers())) > 0:
810
# leaf module with parameters or buffers
@@ -50,3 +52,24 @@ def _forward_hook(module: nn.Module, input_, output_):
5052
module.register_forward_pre_hook(_forward_pre_hook)
5153
module.register_forward_hook(_forward_hook)
5254
setattr(module, "_cpu_offload_enabled", True)
55+
56+
57+
def offload_model_to_dict(module: nn.Module) -> Dict[str, torch.Tensor]:
    """Move ``module`` to CPU and snapshot its parameter/buffer storages.

    The returned mapping (qualified name -> CPU tensor) holds the exact
    storages now attached to the module, so a later
    ``restore_model_from_dict`` can re-attach them without reallocating.
    When an accelerator is available, storages are moved to pinned
    (page-locked) memory so subsequent host-to-device copies are faster
    and can be asynchronous.

    Args:
        module: Model to offload; it is moved to CPU in place.

    Returns:
        Dict mapping parameter/buffer names to their CPU tensors.
    """
    module = module.to("cpu")
    # pin_memory() requires an accelerator; skip it on CPU-only machines
    # instead of raising, so offload still works there.
    pin = torch.cuda.is_available()
    offload_param_dict: Dict[str, torch.Tensor] = {}
    for name, param in module.named_parameters(recurse=True):
        if pin:
            param.data = param.data.pin_memory()
        offload_param_dict[name] = param.data
    for name, buffer in module.named_buffers(recurse=True):
        if pin:
            buffer.data = buffer.data.pin_memory()
        offload_param_dict[name] = buffer.data
    return offload_param_dict
67+
68+
69+
def restore_model_from_dict(module: nn.Module, offload_param_dict: Dict[str, torch.Tensor]):
    """Re-attach previously offloaded CPU storages to ``module``.

    For every parameter and buffer whose qualified name appears in
    ``offload_param_dict``, point its ``.data`` back at the saved tensor;
    names absent from the mapping are left untouched.
    """
    for named_tensors in (
        module.named_parameters(recurse=True),
        module.named_buffers(recurse=True),
    ):
        for name, tensor in named_tensors:
            saved = offload_param_dict.get(name)
            if saved is not None:
                tensor.data = saved

0 commit comments

Comments
 (0)