[checkpoint] fix: normalize LoRA state_dict keys when saving HF model #4770
base: main
Conversation
When saving HuggingFace models with `hf_model` in checkpoint contents, LoRA/PEFT models were saved with an incorrect key structure:

- Keys had a "base_model.model." prefix instead of the standard format
- LoRA weights (lora_A, lora_B) were mixed with base model weights
- The ".base_layer.weight" suffix was preserved instead of ".weight"

This caused the saved checkpoints to be incompatible with vLLM and HuggingFace inference engines that expect the standard model format.

The fix adds a `_save_lora_adapter()` method that:

1. Detects LoRA weights in the state_dict
2. Extracts and saves them as a separate PEFT-compatible adapter
3. Normalizes the remaining keys to standard HuggingFace format

Fixes volcengine#4761

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
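For orientation, here is a rough sketch of the intended key mapping; the example keys are the ones listed in the Problem section of this PR, and the adapter-side naming follows the extraction logic in the diff below:

```python
# Illustrative key mapping only; example keys taken from the PR description.
before = [
    "base_model.model.lm_head.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
]

# After _save_lora_adapter():
hf_model_keys = [  # normalized base-model weights saved as the hf_model checkpoint
    "lm_head.weight",
    "model.layers.0.self_attn.q_proj.weight",
]
adapter_keys = [  # LoRA weights moved to a separate PEFT-compatible adapter
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
]
```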
Code Review
This pull request introduces a fix for saving LoRA/PEFT models by normalizing the state_dict keys to be compatible with HuggingFace inference engines. A new _save_lora_adapter method is added to handle the extraction of LoRA weights and key normalization, and unit tests are provided to verify this new functionality.
The changes are logical and well-tested. I have two main points of feedback:
- The key normalization logic in `_save_lora_adapter` uses `str.replace()`, which is not robust for prefix/suffix removal and could lead to bugs with unexpectedly formatted keys. I've suggested a more robust implementation using `startswith()` and `endswith()`.
- The unit tests duplicate the implementation of `_save_lora_adapter`. This creates a maintainability risk. I've recommended converting the method to a `@staticmethod` to allow direct testing without code duplication.
Addressing these points will improve the robustness and maintainability of the code.
```python
# Excerpt from the test file; `os`, `json`, and `OrderedDict` are imported at module level.
def _save_lora_adapter_standalone(self, state_dict: dict, target_dir: str):
    """
    Standalone implementation of _save_lora_adapter for testing.
    This mirrors the implementation in FSDPCheckpointManager.
    """
    lora_params_names = [name for name in state_dict.keys() if "lora_" in name]
    if len(lora_params_names) == 0:
        return None

    import peft
    from safetensors.torch import save_file

    lora_params = OrderedDict()
    target_modules = set()
    lora_key = None

    for name in lora_params_names:
        lora_key = name.replace(".default.weight", ".weight")
        target_modules.add(lora_key.split(".")[-3])
        lora_params[lora_key] = state_dict.pop(name)

    lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1])
    peft_dict = {
        "r": lora_rank,
        "lora_alpha": 0,
        "target_modules": list(target_modules),
    }
    peft_config = peft.LoraConfig(**peft_dict).to_dict()
    peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None
    peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None
    peft_config["target_modules"] = list(peft_config["target_modules"])

    lora_path = os.path.join(target_dir, "lora_adapter")
    os.makedirs(lora_path, exist_ok=True)
    with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f:
        json.dump(peft_config, f, ensure_ascii=False, indent=4)
    save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors"))

    # Normalize remaining keys to standard HuggingFace format
    for name in list(state_dict.keys()):
        key = (
            name.replace("base_model.model.", "")
            .replace(".base_layer.weight", ".weight")
            .replace(".base_layer.bias", ".bias")
        )
        if key != name:
            state_dict[key] = state_dict.pop(name)

    return lora_path
```
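For context, a minimal sketch of how this helper could be exercised; the tensor shapes, key names, and use of a temporary directory are illustrative assumptions, not the contents of the actual test file:

```python
import tempfile
from collections import OrderedDict

import torch

# Hypothetical inputs: one LoRA pair plus the wrapped base-layer weight for q_proj.
state_dict = OrderedDict({
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.zeros(8, 64),
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight": torch.zeros(64, 8),
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros(64, 64),
})

with tempfile.TemporaryDirectory() as tmp:
    # `self` is unused by the helper, so None suffices for this sketch.
    lora_path = _save_lora_adapter_standalone(None, state_dict, tmp)
    # The LoRA tensors are written under <tmp>/lora_adapter, and the remaining
    # base-model key has been renamed to the standard HuggingFace form.
    assert lora_path.endswith("lora_adapter")
    assert "model.layers.0.self_attn.q_proj.weight" in state_dict
```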
The function _save_lora_adapter_standalone is a copy of the _save_lora_adapter method from FSDPCheckpointManager. This code duplication creates a maintainability issue: if the original method is changed in the future, these tests will be testing outdated logic unless this copy is also manually updated.
The _save_lora_adapter method does not use any instance attributes (self), so it can be converted to a @staticmethod. This would allow you to call it directly from your test file, eliminating the need for code duplication.
I recommend the following changes:
- In `verl/utils/checkpoint/fsdp_checkpoint_manager.py`, add the `@staticmethod` decorator to `_save_lora_adapter`:

  ```python
  @staticmethod
  def _save_lora_adapter(state_dict: dict[str, torch.Tensor], target_dir: str) -> Optional[str]:
      # ... method body ...
  ```

- In this test file, remove the `_save_lora_adapter_standalone` function.

- Import `FSDPCheckpointManager` and call the static method directly in your tests, for example:

  ```python
  from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
  ...
  lora_path = FSDPCheckpointManager._save_lora_adapter(state_dict, temp_dir)
  ```
This will make the tests more robust and easier to maintain.
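For illustration only, one possible shape for the de-duplicated test once the `@staticmethod` conversion is in place; the tensors and the pytest `tmp_path` fixture are assumptions, not part of this PR:

```python
import torch

from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager


def test_save_lora_adapter_writes_peft_files(tmp_path):
    # Hypothetical LoRA-wrapped state_dict for a single q_proj module.
    state_dict = {
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.zeros(8, 64),
        "base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight": torch.zeros(64, 8),
        "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros(64, 64),
    }
    lora_path = FSDPCheckpointManager._save_lora_adapter(state_dict, str(tmp_path))
    # Adapter artifacts should be written under <tmp_path>/lora_adapter.
    assert (tmp_path / "lora_adapter" / "adapter_config.json").exists()
    assert (tmp_path / "lora_adapter" / "adapter_model.safetensors").exists()
```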
@yurekami Could you resolve the concerns raised by Gemini?
```python
key = (
    name.replace("base_model.model.", "")
    .replace(".base_layer.weight", ".weight")
    .replace(".base_layer.bias", ".bias")
)
if key != name:
```
The use of chained str.replace() for normalizing state dictionary keys is not robust. The PR description mentions removing a "prefix", but replace() will substitute all occurrences of the substring, not just at the beginning or end of the string. This could lead to unexpected behavior and corrupted keys if the substrings ("base_model.model.", ".base_layer.weight", ".base_layer.bias") appear in other parts of a parameter name.
To make this logic more robust and align with the intent of removing prefixes/suffixes, please consider using startswith() and endswith() checks.
```python
key = name
if key.startswith("base_model.model."):
    key = key[len("base_model.model.") :]
if key.endswith(".base_layer.weight"):
    key = key[: -len(".base_layer.weight")] + ".weight"
elif key.endswith(".base_layer.bias"):
    key = key[: -len(".base_layer.bias")] + ".bias"
```

…alization

- Replace chained str.replace() calls with explicit startswith/endswith checks to avoid unintended substitutions if patterns appear in the middle of key names
- Convert _save_lora_adapter to @staticmethod to eliminate code duplication in tests
- Update tests to use the actual implementation instead of a duplicated standalone copy

Addresses review feedback from PR volcengine#4770.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Thanks for the review feedback! I've addressed both concerns in commit 3952342:

**1. Unsafe String Replacement (Fixed)**

Replaced the chained `str.replace()` calls with explicit prefix/suffix checks:

```python
# Before (unsafe - replaces ALL occurrences)
key = name.replace("base_model.model.", "").replace(".base_layer.weight", ".weight")

# After (safe - only handles prefix/suffix)
key = name
if key.startswith("base_model.model."):
    key = key[len("base_model.model."):]
if key.endswith(".base_layer.weight"):
    key = key[:-len(".base_layer.weight")] + ".weight"
```

This prevents unintended substitutions if patterns appear in the middle of key names.

**2. Code Duplication (Fixed)**

Converted `_save_lora_adapter` to a `@staticmethod` and updated the tests to call the actual implementation instead of the duplicated standalone copy (see commit above).

All tests pass ✅
Summary
- The `hf_model` checkpoint content had an incorrect state_dict key structure for LoRA/PEFT models
- Added a `_save_lora_adapter()` method to extract LoRA weights and normalize keys to standard HuggingFace format

Problem
When saving checkpoints with `save_contents=['model','optimizer','extra','hf_model']`, LoRA-trained models were saved with:

- `base_model.model.lm_head.weight` instead of `lm_head.weight`
- `base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight` instead of `model.layers.0.self_attn.q_proj.weight`
- `lora_A.default.weight` and `lora_B.default.weight` keys mixed in with base model weights

This made the saved HuggingFace checkpoints incompatible with inference engines.
Solution
Added a `_save_lora_adapter()` method that:

- Detects LoRA weights in the state_dict (keys containing `lora_`)
- Extracts and saves them as a PEFT-compatible adapter under `huggingface/lora_adapter/`
- Normalizes the remaining keys by removing the `base_model.model.` prefix and `.base_layer` suffix
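As an illustration of the compatibility this restores, here is a sketch of how the saved artifacts could then be consumed with standard HuggingFace/PEFT APIs; the checkpoint path is hypothetical, and only the `huggingface/lora_adapter/` layout comes from this PR:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint directory produced with 'hf_model' in save_contents.
ckpt_dir = "checkpoints/global_step_100/actor/huggingface"

# The normalized base-model weights load directly with standard HF APIs...
base_model = AutoModelForCausalLM.from_pretrained(ckpt_dir)

# ...and the extracted LoRA weights load as a regular PEFT adapter.
model = PeftModel.from_pretrained(base_model, f"{ckpt_dir}/lora_adapter")
```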
Test plan

- Added unit tests in `tests/utils/test_lora_checkpoint_on_cpu.py`

Fixes #4761
🤖 Generated with Claude Code