From d87b20d87f3a7388ee15ba1e0e8ca943c175e95c Mon Sep 17 00:00:00 2001
From: christopher5106
Date: Tue, 10 Feb 2026 14:41:18 +0000
Subject: [PATCH 1/3] fixing text encoder lora loading

---
 src/diffusers/loaders/lora_conversion_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index f56632ced819..6a9ae7e86205 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -824,6 +824,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     if not has_mixture:
         state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
         state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
+        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te1_"): v for k, v in state_dict.items()}
 
     has_position_embedding = any("position_embedding" in k for k in state_dict)
     if has_position_embedding:
@@ -856,7 +857,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
         )
         state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
 
-    has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
+    has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
     if has_diffb:
         zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
         if zero_status_diff_b:
@@ -895,7 +896,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
         state_dict = {
             _custom_replace(k, limit_substrings): v
             for k, v in state_dict.items()
-            if k.startswith(("lora_unet_", "lora_te_"))
+            if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
         }
 
     if any("text_projection" in k for k in state_dict):

From 227d90a646268ba065993c573f59e8581ab9477c Mon Sep 17 00:00:00 2001
From: christopher5106
Date: Tue, 10 Feb 2026 17:27:31 +0000
Subject: [PATCH 2/3] following Cursor's review

---
 src/diffusers/loaders/lora_conversion_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 6a9ae7e86205..8b0f95b905e4 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -824,7 +824,6 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     if not has_mixture:
         state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
         state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
-        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te1_"): v for k, v in state_dict.items()}
 
     has_position_embedding = any("position_embedding" in k for k in state_dict)
     if has_position_embedding:

From 09abb9d028ab0fb5c5febdf924b4d16792e24e27 Mon Sep 17 00:00:00 2001
From: christopher5106
Date: Wed, 11 Feb 2026 14:18:25 +0000
Subject: [PATCH 3/3] fix for loading of Kohya's Flux.2 dev lora

---
 .../loaders/lora_conversion_utils.py          | 79 ++++++++++++++-----
 src/diffusers/loaders/lora_pipeline.py        | 10 +++
 2 files changed, 69 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 8b0f95b905e4..67b0b32b6260 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -357,7 +357,10 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
 
 # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
 # are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-def _convert_kohya_flux_lora_to_diffusers(state_dict):
+def _convert_kohya_flux_lora_to_diffusers(
+    state_dict,
+    version_flux2 = False,
+):
     def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
         if sds_key + ".lora_down.weight" not in sds_sd:
             return
@@ -448,7 +451,15 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
     def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         ait_sd = {}
-        for i in range(19):
+
+        max_num_double_blocks, max_num_single_blocks = -1, -1
+        for key in list(sds_sd.keys()):
+            if key.startswith("lora_unet_double_blocks_"):
+                max_num_double_blocks = max(max_num_double_blocks, int(key.split("_")[4]))
+            if key.startswith("lora_unet_single_blocks_"):
+                max_num_single_blocks = max(max_num_single_blocks, int(key.split("_")[4]))
+
+        for i in range(max_num_double_blocks+1):
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
@@ -469,13 +480,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_img_mlp_0",
-                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+                (
+                    f"transformer.transformer_blocks.{i}.ff.linear_in"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff.net.0.proj"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_img_mlp_2",
-                f"transformer.transformer_blocks.{i}.ff.net.2",
+                (
+                    f"transformer.transformer_blocks.{i}.ff.linear_out"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff.net.2"
+                ),
            )
             _convert_to_ai_toolkit(
                 sds_sd,
@@ -503,13 +522,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_txt_mlp_0",
-                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+                (
+                    f"transformer.transformer_blocks.{i}.ff_context.linear_in"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_txt_mlp_2",
-                f"transformer.transformer_blocks.{i}.ff_context.net.2",
+                (
+                    f"transformer.transformer_blocks.{i}.ff_context.linear_out"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff_context.net.2"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
@@ -518,24 +545,36 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 f"transformer.transformer_blocks.{i}.norm1_context.linear",
             )
 
-        for i in range(38):
-            _convert_to_ai_toolkit_cat(
-                sds_sd,
-                ait_sd,
-                f"lora_unet_single_blocks_{i}_linear1",
-                [
-                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
-                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
-                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
-                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
-                ],
-                dims=[3072, 3072, 3072, 12288],
-            )
+        for i in range(max_num_single_blocks+1):
+            if version_flux2:
+                _convert_to_ai_toolkit(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear1",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
+                )
+            else:
+                _convert_to_ai_toolkit_cat(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear1",
+                    [
+                        f"transformer.single_transformer_blocks.{i}.attn.to_q",
+                        f"transformer.single_transformer_blocks.{i}.attn.to_k",
+                        f"transformer.single_transformer_blocks.{i}.attn.to_v",
+                        f"transformer.single_transformer_blocks.{i}.proj_mlp",
+                    ],
+                    dims=[3072, 3072, 3072, 12288],
+                )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_single_blocks_{i}_linear2",
-                f"transformer.single_transformer_blocks.{i}.proj_out",
+                (
+                    f"transformer.single_transformer_blocks.{i}.attn.to_out"
+                    if version_flux2 else
+                    f"transformer.single_transformer_blocks.{i}.proj_out"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index e92eb8ba396d..5a2cfcfb24f0 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -5472,6 +5472,16 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_kohya = any(".lora_down.weight" in k for k in state_dict)
+        if is_kohya:
+            state_dict = _convert_kohya_flux_lora_to_diffusers(
+                state_dict,
+                version_flux2=True,
+            )
+            # Kohya already takes care of scaling the LoRA parameters with alpha.
+            for k in state_dict:
+                assert "alpha" not in k, f"Found key with alpha: {k}"
+
         is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
         if is_ai_toolkit:
             state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)