Merged
4 changes: 3 additions & 1 deletion invokeai/app/invocations/cogview4_text_encoder.py
@@ -10,7 +10,7 @@
CogView4ConditioningInfo,
ConditioningFieldData,
)
from invokeai.backend.util.devices import TorchDevice

Check failure (GitHub Actions / python-checks), Ruff (F401): invokeai/app/invocations/cogview4_text_encoder.py:13:43: `invokeai.backend.util.devices.TorchDevice` imported but unused

# The CogView4 GLM Text Encoder max sequence length set based on the default in diffusers.
COGVIEW4_GLM_MAX_SEQ_LEN = 1024
@@ -37,6 +37,8 @@
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
# Move embeddings to CPU for storage to save VRAM
glm_embeds = glm_embeds.detach().to("cpu")
conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
conditioning_name = context.conditioning.save(conditioning_data)
return CogView4ConditioningOutput.build(conditioning_name)
@@ -85,7 +87,7 @@
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = glm_text_encoder(
text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
text_input_ids.to(glm_text_encoder.device), output_hidden_states=True
).hidden_states[-2]

assert isinstance(prompt_embeds, torch.Tensor)
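
The same encode-then-offload pattern recurs in the text-encoder changes below: run the encoder on whatever device it was loaded to, then detach the result to the CPU so stored conditioning does not pin VRAM. A minimal sketch of that idea, using a plain torch module and a hypothetical `run_encoder` helper rather than the actual InvokeAI invocation API:

```python
import torch


def run_encoder(text_encoder: torch.nn.Module, token_ids: torch.Tensor) -> torch.Tensor:
    """Encode on the device the encoder actually lives on, then offload the result."""
    # Hugging Face models expose `.device`; for a bare nn.Module we fall back
    # to the device of its parameters.
    device = getattr(text_encoder, "device", None) or next(text_encoder.parameters()).device
    with torch.no_grad():
        embeds = text_encoder(token_ids.to(device))
    # Detach and move to CPU for storage; the denoiser moves the tensor back
    # to its own device when it is consumed.
    return embeds.detach().to("cpu")


if __name__ == "__main__":
    # Stand-in "encoder": an embedding table, just to exercise the helper.
    encoder = torch.nn.Embedding(num_embeddings=1000, embedding_dim=16)
    ids = torch.randint(0, 1000, (1, 8))
    print(run_encoder(encoder, ids).device)  # cpu
```
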
4 changes: 2 additions & 2 deletions invokeai/app/invocations/compel.py
@@ -103,7 +103,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
device=text_encoder.device, # Use the device the model is actually on
split_long_text_mode=SplitLongTextMode.SENTENCES,
)

@@ -212,7 +212,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
device=text_encoder.device, # Use the device the model is actually on
split_long_text_mode=SplitLongTextMode.SENTENCES,
)

6 changes: 6 additions & 0 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -58,6 +58,12 @@
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
t5_embeddings = self._t5_encode(context)
clip_embeddings = self._clip_encode(context)

Check failure (GitHub Actions / python-checks), Ruff (W293): invokeai/app/invocations/flux_text_encoder.py:61:1: Blank line contains whitespace
# Move embeddings to CPU for storage to save VRAM
# They will be moved to the appropriate device when used by the denoiser
t5_embeddings = t5_embeddings.detach().to("cpu")
clip_embeddings = clip_embeddings.detach().to("cpu")

Check failure (GitHub Actions / python-checks), Ruff (W293): invokeai/app/invocations/flux_text_encoder.py:66:1: Blank line contains whitespace
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
13 changes: 11 additions & 2 deletions invokeai/app/invocations/sd3_text_encoder.py
@@ -21,7 +21,7 @@
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice

Check failure (GitHub Actions / python-checks), Ruff (F401): invokeai/app/invocations/sd3_text_encoder.py:24:43: `invokeai.backend.util.devices.TorchDevice` imported but unused

# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@@ -69,6 +69,15 @@
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)

# Move all embeddings to CPU for storage to save VRAM
# They will be moved to the appropriate device when used by the denoiser
clip_l_embeddings = clip_l_embeddings.detach().to("cpu")
clip_l_pooled_embeddings = clip_l_pooled_embeddings.detach().to("cpu")
clip_g_embeddings = clip_g_embeddings.detach().to("cpu")
clip_g_pooled_embeddings = clip_g_pooled_embeddings.detach().to("cpu")
if t5_embeddings is not None:
t5_embeddings = t5_embeddings.detach().to("cpu")

conditioning_data = ConditioningFieldData(
conditionings=[
SD3ConditioningInfo(
@@ -117,7 +126,7 @@
f" {max_seq_len} tokens: {removed_text}"
)

prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]

assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
@@ -180,7 +189,7 @@
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
6 changes: 5 additions & 1 deletion invokeai/app/invocations/z_image_text_encoder.py
@@ -57,6 +57,8 @@ class ZImageTextEncoderInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
# Move embeddings to CPU for storage to save VRAM
prompt_embeds = prompt_embeds.detach().to("cpu")
conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
conditioning_name = context.conditioning.save(conditioning_data)
return ZImageConditioningOutput(
@@ -69,7 +71,6 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
Based on the ZImagePipeline._encode_prompt method from diffusers.
"""
prompt = self.prompt
device = TorchDevice.choose_torch_device()

text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
@@ -78,6 +79,9 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
(_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

# Use the device that the text_encoder is actually on
device = text_encoder.device

# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
exit_stack.enter_context(
3 changes: 3 additions & 0 deletions invokeai/backend/model_manager/configs/qwen3_encoder.py
@@ -51,6 +51,7 @@ class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -87,6 +88,7 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -130,6 +132,7 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/configs/t5_encoder.py
@@ -21,6 +21,7 @@ class T5Encoder_T5Encoder_Config(Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -50,6 +51,7 @@ class T5Encoder_BnBLLMint8_Config(Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
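
All five encoder configs above gain the same optional flag. As a rough sketch of what that field amounts to on a pydantic model (the class name `EncoderConfigSketch` and the `path` field are illustrative, not the real `*_Config` classes):

```python
from pydantic import BaseModel, Field


class EncoderConfigSketch(BaseModel):
    """Illustrative subset of a standalone text-encoder model config."""

    path: str
    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")


# The default of None leaves existing configs untouched; only an explicit
# cpu_only=True pins the encoder to the CPU.
assert EncoderConfigSketch(path="models/t5_xxl").cpu_only is None
assert EncoderConfigSketch(path="models/t5_xxl", cpu_only=True).cpu_only is True
```
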
22 changes: 16 additions & 6 deletions invokeai/backend/model_manager/load/load_default.py
@@ -68,16 +68,26 @@
model_base = self._app_config.models_path
return (model_base / config.path).resolve()

def _get_execution_device(self, config: AnyModelConfig) -> Optional[torch.device]:
def _get_execution_device(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> Optional[torch.device]:
"""Determine the execution device for a model based on its configuration.


Check failure (GitHub Actions / python-checks), Ruff (W293): invokeai/backend/model_manager/load/load_default.py:73:1: Blank line contains whitespace
CPU-only execution is only applied to text encoder submodels to save VRAM while keeping
the denoiser on the GPU for performance. Conditioning tensors are offloaded to the CPU after encoding and moved back to the GPU when the denoiser consumes them.

Check failure (GitHub Actions / python-checks), Ruff (W293): invokeai/backend/model_manager/load/load_default.py:76:1: Blank line contains whitespace
Returns:
torch.device("cpu") if the model should run on CPU only, None otherwise (use cache default).
"""
# Check if this is a main model with default settings that specify cpu_only
# Check if this is a text encoder submodel of a main model with cpu_only setting
if hasattr(config, "default_settings") and config.default_settings is not None:
if hasattr(config.default_settings, "cpu_only") and config.default_settings.cpu_only is True:
return torch.device("cpu")
# Only apply CPU execution to text encoder submodels
if submodel_type in [SubModelType.TextEncoder, SubModelType.TextEncoder2, SubModelType.TextEncoder3]:
return torch.device("cpu")

Check failure (GitHub Actions / python-checks), Ruff (W293): invokeai/backend/model_manager/load/load_default.py:86:1: Blank line contains whitespace
# Check if this is a standalone text encoder config with cpu_only field (T5Encoder, Qwen3Encoder, etc.)
if hasattr(config, "cpu_only") and config.cpu_only is True:
return torch.device("cpu")

Check failure (GitHub Actions / python-checks), Ruff (W293): invokeai/backend/model_manager/load/load_default.py:90:1: Blank line contains whitespace
return None

def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
@@ -91,8 +101,8 @@
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)

# Determine execution device from model config
execution_device = self._get_execution_device(config)
# Determine execution device from model config, considering submodel type
execution_device = self._get_execution_device(config, submodel_type)

self._ram_cache.put(
get_model_cache_key(config.key, submodel_type),
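
A self-contained sketch of the selection rule that `_get_execution_device` now implements, with simplified stand-ins for `AnyModelConfig` and `SubModelType` (the real enum and config classes live elsewhere in the model manager):

```python
from enum import Enum
from types import SimpleNamespace
from typing import Any, Optional

import torch


class SubModelType(str, Enum):
    # Reduced stand-in for InvokeAI's SubModelType enum.
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    TextEncoder3 = "text_encoder_3"
    UNet = "unet"


_TEXT_ENCODERS = {SubModelType.TextEncoder, SubModelType.TextEncoder2, SubModelType.TextEncoder3}


def pick_execution_device(config: Any, submodel_type: Optional[SubModelType] = None) -> Optional[torch.device]:
    """Return torch.device('cpu') for CPU-pinned text encoders, else None (use the cache default)."""
    # Main-model configs: default_settings.cpu_only only affects text-encoder
    # submodels, so the denoiser keeps running on the GPU.
    default_settings = getattr(config, "default_settings", None)
    if getattr(default_settings, "cpu_only", None) is True and submodel_type in _TEXT_ENCODERS:
        return torch.device("cpu")
    # Standalone encoder configs (T5, Qwen3, ...) carry cpu_only directly.
    if getattr(config, "cpu_only", None) is True:
        return torch.device("cpu")
    return None


main_cfg = SimpleNamespace(default_settings=SimpleNamespace(cpu_only=True))
assert pick_execution_device(main_cfg, SubModelType.TextEncoder) == torch.device("cpu")
assert pick_execution_device(main_cfg, SubModelType.UNet) is None
assert pick_execution_device(SimpleNamespace(cpu_only=True)) == torch.device("cpu")
```
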
[changes to a separate file; file header not captured in this view]
@@ -60,6 +60,11 @@ def is_in_vram(self) -> bool:
"""Return true if the model is currently in VRAM."""
return self._is_in_vram

@property
def compute_device(self) -> torch.device:
"""Return the compute device for this model."""
return self._compute_device

def full_load_to_vram(self) -> int:
"""Load all weights into VRAM (if supported by the model).
Returns:
[changes to a separate file; file header not captured in this view]
@@ -136,6 +136,11 @@ def cur_vram_bytes(self) -> int:
)
return self._cur_vram_bytes

@property
def compute_device(self) -> torch.device:
"""Return the compute device for this model."""
return self._compute_device

def full_load_to_vram(self) -> int:
"""Load all weights into VRAM."""
return self.partial_load_to_vram(self.total_bytes())
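
Both cached-model wrappers now expose the device they were assigned at load time. A tiny sketch of that shape (field names invented), enough for the cache to make per-model decisions without consulting its global execution device:

```python
from dataclasses import dataclass

import torch


@dataclass
class CachedModelSketch:
    """Illustrative stand-in for the cached-model wrapper classes."""

    model: torch.nn.Module
    _compute_device: torch.device  # assigned by the model cache when the model is loaded

    @property
    def compute_device(self) -> torch.device:
        """Return the compute device for this model."""
        return self._compute_device


entry = CachedModelSketch(model=torch.nn.Linear(4, 4), _compute_device=torch.device("cpu"))
assert entry.compute_device.type == "cpu"
```
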
13 changes: 10 additions & 3 deletions invokeai/backend/model_manager/load/model_cache/model_cache.py
@@ -438,8 +438,13 @@ def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
f"Locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)

if self._execution_device.type == "cpu":
# Models don't need to be loaded into VRAM if we're running on CPU.
# Check if the model's specific compute_device is CPU, not just the cache's default execution_device
model_compute_device = cache_entry.cached_model.compute_device
if model_compute_device.type == "cpu":
# Models configured for CPU execution don't need to be loaded into VRAM
self._logger.debug(
f"Model {cache_entry.key} is configured for CPU execution, skipping VRAM load"
)
return

try:
@@ -521,9 +526,11 @@ def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Option
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
vram_available = self._get_vram_available(working_mem_bytes)
loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
# Use the model's actual compute_device for logging, not the cache's default
model_device = cache_entry.cached_model.compute_device
self._logger.info(
f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto "
f"{self._execution_device.type} device in {(time.time() - start_time):.2f}s. "
f"{model_device.type} device in {(time.time() - start_time):.2f}s. "
f"Total model size: {model_total_bytes / MB:.2f}MB, "
f"VRAM: {model_cur_vram_bytes / MB:.2f}MB ({loaded_percent:.1%})"
)
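
Condensed, the new lock-time behaviour looks roughly like the following; `CacheEntrySketch` and `lock_for_inference` are invented names standing in for the real cache record and lock method:

```python
import logging

import torch

logger = logging.getLogger("model_cache_sketch")


class CacheEntrySketch:
    """Illustrative stand-in for a cache record wrapping a cached model."""

    def __init__(self, key: str, model: torch.nn.Module, compute_device: torch.device) -> None:
        self.key = key
        self.model = model
        self.compute_device = compute_device

    def full_load_to_vram(self) -> None:
        # The real implementation streams weights into VRAM and tracks bytes;
        # here we just move the module to its compute device.
        self.model.to(self.compute_device)


def lock_for_inference(entry: CacheEntrySketch) -> None:
    # Per-model check: a CPU-pinned entry (for example an offloaded text
    # encoder) skips the VRAM load even when the cache's default execution
    # device is CUDA.
    if entry.compute_device.type == "cpu":
        logger.debug("Model %s is configured for CPU execution, skipping VRAM load", entry.key)
        return
    entry.full_load_to_vram()


lock_for_inference(CacheEntrySketch("t5_xxl", torch.nn.Linear(4, 4), torch.device("cpu")))
```
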