Conversation

@yurekami (Contributor) commented Jan 2, 2026

Summary

  • Fix issue where LoRA/PEFT models saved with hf_model checkpoint content had incorrect state_dict key structure
  • Add _save_lora_adapter() method to extract LoRA weights and normalize keys to standard HuggingFace format
  • Saved checkpoints are now compatible with vLLM and HuggingFace inference engines

Problem

When saving checkpoints with save_contents=['model','optimizer','extra','hf_model'], LoRA-trained models were saved with:

  • base_model.model.lm_head.weight instead of lm_head.weight
  • base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight instead of model.layers.0.self_attn.q_proj.weight
  • Mixed lora_A.default.weight and lora_B.default.weight keys

This made the saved HuggingFace checkpoints incompatible with inference engines, as illustrated below.
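
For reference, the mismatch is easy to observe by loading the saved weights directly. A minimal sketch, assuming a safetensors file was written under the checkpoint's huggingface/ directory (the path and file name below are hypothetical; the key names are the ones listed above):

# Illustrative only: the checkpoint path and file name are hypothetical.
from safetensors.torch import load_file

state_dict = load_file("checkpoints/global_step_100/huggingface/model.safetensors")
print([k for k in state_dict if "lora_" in k or "base_layer" in k][:2])
# ['base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight',
#  'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight']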

Solution

Added _save_lora_adapter() method that:

  1. Detects LoRA weights in state_dict (keys containing lora_)
  2. Extracts LoRA weights and saves them as a PEFT-compatible adapter in huggingface/lora_adapter/
  3. Normalizes the remaining keys by removing the base_model.model. prefix and the .base_layer suffix (see the sketch below)
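
A minimal sketch of the key-normalization step (the helper name normalize_lora_keys is made up for illustration; the real logic lives in FSDPCheckpointManager._save_lora_adapter and uses the prefix/suffix checks discussed in the review below):

# Sketch only: rename PEFT-wrapped keys to plain HuggingFace keys.
def normalize_lora_keys(state_dict: dict) -> dict:
    normalized = {}
    for name, tensor in state_dict.items():
        key = name
        if key.startswith("base_model.model."):
            key = key[len("base_model.model."):]
        if key.endswith(".base_layer.weight"):
            key = key[:-len(".base_layer.weight")] + ".weight"
        elif key.endswith(".base_layer.bias"):
            key = key[:-len(".base_layer.bias")] + ".bias"
        normalized[key] = tensor
    return normalized

# "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
#   -> "model.layers.0.self_attn.q_proj.weight"

The LoRA tensors themselves (keys containing lora_) are popped out before this step and written to huggingface/lora_adapter/ as adapter_config.json plus adapter_model.safetensors, so the remaining keys describe only the base model.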

Test plan

  • Unit tests added in tests/utils/test_lora_checkpoint_on_cpu.py (a simplified sketch follows this list)
  • Tests verify key normalization logic
  • Tests verify non-LoRA models pass through unchanged
  • Ruff linting passes
  • E2E test with actual LoRA training (requires GPU)
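
A rough sketch of what such a CPU-only test can look like once the review changes below are applied (calling the static method directly); the key names follow the format above, but the tensor shapes and assertions are illustrative rather than copied from the actual test file:

# Illustrative test sketch; needs only torch, peft, and safetensors on CPU.
import tempfile

import torch

from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager


def test_lora_keys_are_normalized():
    state_dict = {
        "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros(8, 8),
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.zeros(4, 8),
        "base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight": torch.zeros(8, 4),
        "base_model.model.lm_head.weight": torch.zeros(8, 8),
    }
    with tempfile.TemporaryDirectory() as tmp:
        lora_path = FSDPCheckpointManager._save_lora_adapter(state_dict, tmp)
        assert lora_path is not None
        # LoRA tensors were extracted into the adapter; remaining keys are plain HF keys.
        assert "model.layers.0.self_attn.q_proj.weight" in state_dict
        assert "lm_head.weight" in state_dict
        assert not any("lora_" in k for k in state_dict)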

Fixes #4761

🤖 Generated with Claude Code

When saving HuggingFace models with `hf_model` in checkpoint contents,
LoRA/PEFT models were saved with incorrect key structure:
- Keys had "base_model.model." prefix instead of standard format
- LoRA weights (lora_A, lora_B) were mixed with base model weights
- ".base_layer.weight" suffix was preserved instead of ".weight"

This caused the saved checkpoints to be incompatible with vLLM and
HuggingFace inference engines that expect standard model format.

The fix adds _save_lora_adapter() method that:
1. Detects LoRA weights in state_dict
2. Extracts and saves them as a separate PEFT-compatible adapter
3. Normalizes remaining keys to standard HuggingFace format

Fixes volcengine#4761

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a fix for saving LoRA/PEFT models by normalizing the state_dict keys to be compatible with HuggingFace inference engines. A new _save_lora_adapter method is added to handle the extraction of LoRA weights and key normalization, and unit tests are provided to verify this new functionality.

The changes are logical and well-tested. I have two main points of feedback:

  1. The key normalization logic in _save_lora_adapter uses str.replace(), which is not robust for prefix/suffix removal and could lead to bugs with unexpectedly formatted keys. I've suggested a more robust implementation using startswith() and endswith().
  2. The unit tests duplicate the implementation of _save_lora_adapter. This creates a maintainability risk. I've recommended converting the method to a @staticmethod to allow direct testing without code duplication.

Addressing these points will improve the robustness and maintainability of the code.

Comment on lines 69 to 117
def _save_lora_adapter_standalone(self, state_dict: dict, target_dir: str):
    """
    Standalone implementation of _save_lora_adapter for testing.
    This mirrors the implementation in FSDPCheckpointManager.
    """
    lora_params_names = [name for name in state_dict.keys() if "lora_" in name]
    if len(lora_params_names) == 0:
        return None

    import peft
    from safetensors.torch import save_file

    lora_params = OrderedDict()
    target_modules = set()
    lora_key = None

    for name in lora_params_names:
        lora_key = name.replace(".default.weight", ".weight")
        target_modules.add(lora_key.split(".")[-3])
        lora_params[lora_key] = state_dict.pop(name)

    lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1])
    peft_dict = {
        "r": lora_rank,
        "lora_alpha": 0,
        "target_modules": list(target_modules),
    }
    peft_config = peft.LoraConfig(**peft_dict).to_dict()
    peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None
    peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None
    peft_config["target_modules"] = list(peft_config["target_modules"])

    lora_path = os.path.join(target_dir, "lora_adapter")
    os.makedirs(lora_path, exist_ok=True)
    with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f:
        json.dump(peft_config, f, ensure_ascii=False, indent=4)
    save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors"))

    # Normalize remaining keys to standard HuggingFace format
    for name in list(state_dict.keys()):
        key = (
            name.replace("base_model.model.", "")
            .replace(".base_layer.weight", ".weight")
            .replace(".base_layer.bias", ".bias")
        )
        if key != name:
            state_dict[key] = state_dict.pop(name)

    return lora_path

gemini-code-assist bot commented (severity: high):

The function _save_lora_adapter_standalone is a copy of the _save_lora_adapter method from FSDPCheckpointManager. This code duplication creates a maintainability issue: if the original method is changed in the future, these tests will be testing outdated logic unless this copy is also manually updated.

The _save_lora_adapter method does not use any instance attributes (self), so it can be converted to a @staticmethod. This would allow you to call it directly from your test file, eliminating the need for code duplication.

I recommend the following changes:

  1. In verl/utils/checkpoint/fsdp_checkpoint_manager.py, add the @staticmethod decorator to _save_lora_adapter.
    @staticmethod
    def _save_lora_adapter(state_dict: dict[str, torch.Tensor], target_dir: str) -> Optional[str]:
        # ... method body ...
  2. In this test file, remove the _save_lora_adapter_standalone function.
  3. Import FSDPCheckpointManager and call the static method directly in your tests, for example:
    from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
    ...
    lora_path = FSDPCheckpointManager._save_lora_adapter(state_dict, temp_dir)

This will make the tests more robust and easier to maintain.

A collaborator commented:
@yurekami Could you resolve the concerns raised by Gemini?

Comment on lines 154 to 158
                name.replace("base_model.model.", "")
                .replace(".base_layer.weight", ".weight")
                .replace(".base_layer.bias", ".bias")
            )
            if key != name:

gemini-code-assist bot commented (severity: high):

The use of chained str.replace() for normalizing state dictionary keys is not robust. The PR description mentions removing a "prefix", but replace() will substitute all occurrences of the substring, not just at the beginning or end of the string. This could lead to unexpected behavior and corrupted keys if the substrings ("base_model.model.", ".base_layer.weight", ".base_layer.bias") appear in other parts of a parameter name.

To make this logic more robust and align with the intent of removing prefixes/suffixes, please consider using startswith() and endswith() checks.

            key = name
            if key.startswith("base_model.model."):
                key = key[len("base_model.model.") :]

            if key.endswith(".base_layer.weight"):
                key = key[: -len(".base_layer.weight")] + ".weight"
            elif key.endswith(".base_layer.bias"):
                key = key[: -len(".base_layer.bias")] + ".bias"

…alization

- Replace chained str.replace() calls with explicit startswith/endswith
  checks to avoid unintended substitutions if patterns appear in middle
  of key names
- Convert _save_lora_adapter to @staticmethod to eliminate code
  duplication in tests
- Update tests to use the actual implementation instead of a duplicated
  standalone copy

Addresses review feedback from PR volcengine#4770.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

@yurekami (Contributor, Author) commented Jan 3, 2026

Thanks for the review feedback! I've addressed both concerns in commit 3952342:

1. Unsafe String Replacement (Fixed)

Replaced the chained str.replace() calls with explicit startswith()/endswith() checks:

# Before (unsafe - replaces ALL occurrences)
key = name.replace("base_model.model.", "").replace(".base_layer.weight", ".weight")

# After (safe - only handles prefix/suffix)
key = name
if key.startswith("base_model.model."):
    key = key[len("base_model.model."):]
if key.endswith(".base_layer.weight"):
    key = key[:-len(".base_layer.weight")] + ".weight"

This prevents unintended substitutions if patterns appear in the middle of key names.
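
As a contrived illustration (this key is made up and would not be produced by PEFT), the difference matters whenever the pattern appears mid-key:

name = "model.layers.0.base_model.model.proj.weight"  # contrived key
name.replace("base_model.model.", "")                 # -> "model.layers.0.proj.weight" (corrupted)
# The startswith() check leaves this key untouched, which is the intended behavior.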

2. Code Duplication (Fixed)

  • Converted _save_lora_adapter() to a @staticmethod
  • Removed the duplicated _save_lora_adapter_standalone() from the test file (50 lines deleted)
  • Tests now call FSDPCheckpointManager._save_lora_adapter() directly

All tests pass ✅

Linked issue: [Bug] no merged model under the folder huggingface