From 74e8770ca19923b2630d85602ba40c00e8cd5519 Mon Sep 17 00:00:00 2001
From: Akari
Date: Tue, 10 Feb 2026 15:26:17 +0100
Subject: [PATCH] fix: handle remaining meta tensors in from_single_file before dispatch

When loading models from single-file checkpoints (e.g., GGUF format), some
parameters or buffers may not be present in the checkpoint and remain on the
meta device after load_model_dict_into_meta. This causes dispatch_model to
fail with 'Cannot copy out of meta tensor' errors.

This fix adds three layers of handling before dispatch_model():

1. **Non-persistent buffer re-materialization**: Identifies submodules with
   non-persistent meta buffers (e.g., RoPE sinusoidal embeddings in
   WanTransformer3DModel) and re-creates them outside the init_empty_weights
   context, allowing deterministic buffers to be properly computed.
2. **Persistent buffer fallback**: Zero-initializes any persistent buffers
   still on meta device with a warning.
3. **Parameter safety net**: Zero-initializes any parameters still on meta
   device (indicates incomplete key mapping in the checkpoint converter)
   with a warning.

Fixes #12009
---
 src/diffusers/loaders/single_file_model.py | 98 ++++++++++++++++++++++
 1 file changed, 98 insertions(+)

diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index ab4340fed120..2f38719a7475 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -524,6 +524,104 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         model.eval()
 
         if device_map is not None:
+            # Fix: handle remaining meta tensors before dispatch.
+            # When loading from single-file checkpoints (e.g., GGUF), some model parameters
+            # or buffers may not be present in the checkpoint and remain on the meta device
+            # after `load_model_dict_into_meta`. This causes `dispatch_model` to fail with
+            # "Cannot copy out of meta tensor" errors.
+            if is_accelerate_available():
+                from accelerate.utils import set_module_tensor_to_device
+
+                _target_device = list(device_map.values())[0] if device_map else "cpu"
+
+                # Collect submodules that have non-persistent buffers still on meta.
+                # These buffers are computed during __init__ (e.g., RoPE sinusoidal embeddings)
+                # and were never saved in the checkpoint. We need to re-materialize them by
+                # re-creating the submodule outside of the init_empty_weights context.
+                _modules_to_reinit = set()
+                for _name, _buf in list(model.named_buffers()):
+                    if _buf.device == torch.device("meta"):
+                        _parts = _name.rsplit(".", 1)
+                        _parent_path = _parts[0]
+                        _buf_name = _parts[-1]
+                        _submodule = model.get_submodule(_parent_path)
+                        _is_persistent = _buf_name not in _submodule._non_persistent_buffers_set
+                        if not _is_persistent:
+                            _modules_to_reinit.add(_parent_path)
+                        else:
+                            logger.warning(
+                                f"Buffer '{_name}' is still on meta device after loading. "
+                                f"Initializing to zeros."
+                            )
+                            set_module_tensor_to_device(
+                                model, _name, _target_device,
+                                value=torch.zeros(
+                                    _buf.shape, dtype=_buf.dtype, device=_target_device
+                                ),
+                            )
+
+                # Re-create submodules with non-persistent meta buffers.
+                # This recomputes deterministic buffers (e.g., sinusoidal positional embeddings)
+                # that were created as meta tensors under init_empty_weights().
+                for _mod_path in _modules_to_reinit:
+                    _submodule = model.get_submodule(_mod_path)
+                    _parent_path, _child_name = _mod_path.rsplit(".", 1) if "." in _mod_path else ("", _mod_path)
+                    _parent = model.get_submodule(_parent_path) if _parent_path else model
+                    _cls = type(_submodule)
+
+                    # Reconstruct using config attributes stored on the submodule.
+                    _init_args = {}
+                    import inspect
+                    _sig = inspect.signature(_cls.__init__)
+                    for _param_name, _param in _sig.parameters.items():
+                        if _param_name == "self":
+                            continue
+                        if hasattr(_submodule, _param_name):
+                            _init_args[_param_name] = getattr(_submodule, _param_name)
+                        elif _param.default is not inspect.Parameter.empty:
+                            pass  # use default
+                        else:
+                            break  # can't reconstruct, fall back below
+                    else:
+                        try:
+                            logger.info(
+                                f"Re-creating submodule '{_mod_path}' ({_cls.__name__}) "
+                                f"to materialize non-persistent buffers."
+                            )
+                            _new_submodule = _cls(**_init_args)
+                            setattr(_parent, _child_name, _new_submodule)
+                            continue
+                        except Exception as e:
+                            logger.warning(
+                                f"Failed to re-create '{_mod_path}': {e}. "
+                                f"Falling back to zero-initialization."
+                            )
+
+                    # Fallback: zero-init any remaining meta buffers in this submodule.
+                    for _name, _buf in list(model.named_buffers()):
+                        if _name.startswith(_mod_path) and _buf.device == torch.device("meta"):
+                            set_module_tensor_to_device(
+                                model, _name, _target_device,
+                                value=torch.zeros(
+                                    _buf.shape, dtype=_buf.dtype, device=_target_device
+                                ),
+                            )
+
+                # Handle any remaining meta parameters (should not happen with correct key mapping).
+                for _name, _param in list(model.named_parameters()):
+                    if _param.device == torch.device("meta"):
+                        logger.warning(
+                            f"Parameter '{_name}' is still on meta device after loading. "
+                            f"This likely indicates an incomplete checkpoint key mapping. "
+                            f"Initializing to zeros."
+                        )
+                        set_module_tensor_to_device(
+                            model, _name, _target_device,
+                            value=torch.zeros(
+                                _param.shape, dtype=_param.dtype, device=_target_device
+                            ),
+                        )
+
             device_map_kwargs = {"device_map": device_map}
             dispatch_model(model, **device_map_kwargs)
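Note for reviewers (not part of the patch): a minimal sketch of how this code path can be exercised. The checkpoint path below is a placeholder, and the example assumes a GGUF single-file checkpoint for a model that registers non-persistent buffers (e.g., the RoPE tables in WanTransformer3DModel) and that from_single_file accepts a device_map here, as the patched branch implies; adjust to your setup.

    import torch
    from diffusers import GGUFQuantizationConfig, WanTransformer3DModel

    # Hypothetical local GGUF checkpoint; any single-file checkpoint that leaves
    # non-persistent buffers on the meta device reproduces the original failure.
    ckpt_path = "wan2.1-t2v-14b-Q4_K_M.gguf"

    model = WanTransformer3DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
        device_map="cuda",  # previously failed in dispatch_model with "Cannot copy out of meta tensor"
    )

    # After loading, no tensor should remain on the meta device.
    assert not any(p.device.type == "meta" for p in model.parameters())
    assert not any(b.device.type == "meta" for b in model.buffers())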