66from PIL import Image
77
88from diffsynth_engine .configs import BaseConfig , BaseStateDicts
9- from diffsynth_engine .utils .offload import enable_sequential_cpu_offload
9+ from diffsynth_engine .utils .offload import enable_sequential_cpu_offload , offload_model_to_dict , restore_model_from_dict
1010from diffsynth_engine .utils .fp8_linear import enable_fp8_autocast
1111from diffsynth_engine .utils .gguf import load_gguf_checkpoint
1212from diffsynth_engine .utils import logging
@@ -40,6 +40,7 @@ def __init__(
4040 self .dtype = dtype
4141 self .offload_mode = None
4242 self .model_names = []
43+ self ._offload_param_dict = {}
4344
4445 @classmethod
4546 def from_pretrained (cls , model_path_or_config : str | BaseConfig ) -> "BasePipeline" :
def _enable_model_cpu_offload(self):
    """Enable per-model CPU offload.

    Each registered sub-model's parameters are captured into
    ``self._offload_param_dict`` via ``offload_model_to_dict`` so they can be
    restored on demand later, then the pipeline is flagged as running in
    ``"cpu_offload"`` mode.
    """
    for name in self.model_names:
        module = getattr(self, name)
        if module is None:
            # Optional component not loaded for this pipeline; nothing to offload.
            continue
        self._offload_param_dict[name] = offload_model_to_dict(module)
    self.offload_mode = "cpu_offload"
def _enable_sequential_cpu_offload(self):
    """Enable layer-by-layer (sequential) CPU offload.

    Every non-None registered sub-model is wrapped by
    ``enable_sequential_cpu_offload`` targeting ``self.device``; the pipeline
    is then marked as running in ``"sequential_cpu_offload"`` mode.
    """
    for name in self.model_names:
        module = getattr(self, name)
        if module is None:
            # Skip components that were never instantiated.
            continue
        enable_sequential_cpu_offload(module, self.device)
    self.offload_mode = "sequential_cpu_offload"
256256
@@ -277,20 +277,12 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
277277 for model_name in self .model_names :
278278 if model_name not in load_model_names :
279279 model = getattr (self , model_name )
280- if (
281- model is not None
282- and (p := next (model .parameters (), None )) is not None
283- and p .device != torch .device ("cpu" )
284- ):
285- model .to ("cpu" )
280+ if model is not None and (p := next (model .parameters (), None )) is not None and p .device .type != "cpu" :
281+ restore_model_from_dict (model , self ._offload_param_dict [model_name ])
286282 # load the needed models to device
287283 for model_name in load_model_names :
288284 model = getattr (self , model_name )
289- if (
290- model is not None
291- and (p := next (model .parameters (), None )) is not None
292- and p .device != torch .device (self .device )
293- ):
285+ if model is not None and (p := next (model .parameters (), None )) is not None and p .device .type != self .device :
294286 model .to (self .device )
295287 # fresh the cuda cache
296288 empty_cache ()
0 commit comments