
@yurekami

Summary

  • Fix dtype mismatch crash when resuming Muon optimizer training from checkpoint with bf16 enabled
  • Add load_state_dict override to cast momentum_buffer to match parameter dtype after loading
  • Add unit test for bf16 checkpoint resume scenario

Root Cause

When resuming training from a checkpoint with bf16 enabled:

  1. momentum_buffer is saved as fp32 in the checkpoint
  2. After load_state_dict(), momentum_buffer remains fp32
  3. Gradients in bf16 mode are bf16
  4. momentum.lerp_(grad, 1 - beta) raises a dtype-mismatch RuntimeError, since lerp_ requires its inputs to share a dtype (minimal repro below)
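For illustration, a minimal standalone repro of the failure mode (plain PyTorch, outside DeepSpeed; shapes and beta are arbitrary):

```python
import torch

# momentum_buffer comes back from the checkpoint as fp32 ...
momentum = torch.zeros(16, dtype=torch.float32)
# ... while bf16 training produces bf16 gradients
grad = torch.randn(16, dtype=torch.bfloat16)
beta = 0.95

# lerp_ requires self and end to share a dtype, so this raises a RuntimeError
momentum.lerp_(grad, 1 - beta)
```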

Fix

Added a load_state_dict override to all Muon optimizer classes that casts the optimizer state buffers to the parameter dtype after loading (a sketch follows the table):

Class                        Buffers Fixed
Muon                         momentum_buffer
SingleDeviceMuon             momentum_buffer
MuonWithAuxAdam              momentum_buffer, exp_avg, exp_avg_sq
SingleDeviceMuonWithAuxAdam  momentum_buffer, exp_avg, exp_avg_sq
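A minimal sketch of the pattern (illustrative only, not the exact DeepSpeed code; DtypeSafeSGD is a hypothetical stand-in built on torch.optim.SGD, whose momentum state happens to use the same momentum_buffer key):

```python
import torch

class DtypeSafeSGD(torch.optim.SGD):
    # Buffers to realign; the hybrid Muon classes also carry the Adam buffers.
    STATE_BUFFERS = ("momentum_buffer", "exp_avg", "exp_avg_sq")

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        # After restoring, cast each state buffer to its parameter's dtype
        # so later in-place ops (lerp_, etc.) see matching dtypes.
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state.get(p, {})
                for name in self.STATE_BUFFERS:
                    buf = state.get(name)
                    if torch.is_tensor(buf) and buf.dtype != p.dtype:
                        state[name] = buf.to(p.dtype)
```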

Test Plan

  • Added a TestMuonBF16CheckpointResume test class (condensed sketch after this list) that:
    1. Creates a model with bf16 enabled and a Muon optimizer
    2. Trains for a few steps (creating momentum_buffer state)
    3. Saves a checkpoint
    4. Loads the checkpoint
    5. Resumes training (validates the fix)
  • Tests both ZeRO stage 1 and stage 2
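A condensed sketch of the scenario (plain PyTorch, no DeepSpeed engine or ZeRO; the fp32 upcast mimics the checkpoint storing optimizer state in fp32, and DtypeSafeSGD is the hypothetical stand-in from the sketch above):

```python
import torch

model = torch.nn.Linear(8, 4).to(torch.bfloat16)
opt = DtypeSafeSGD(model.parameters(), lr=1e-2, momentum=0.9)

# Steps 1-2: a few bf16 training steps populate momentum_buffer.
for _ in range(3):
    model(torch.randn(2, 8, dtype=torch.bfloat16)).sum().backward()
    opt.step()
    opt.zero_grad()

# Step 3: checkpoint; upcast state to fp32 to mimic the saved format.
ckpt = opt.state_dict()
for s in ckpt["state"].values():
    s["momentum_buffer"] = s["momentum_buffer"].float()

# Step 4: load; the override realigns buffers with the bf16 params.
opt.load_state_dict(ckpt)

# Step 5: resuming must not raise a dtype-mismatch error.
model(torch.randn(2, 8, dtype=torch.bfloat16)).sum().backward()
opt.step()
```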

Fixes: #7746

🤖 Generated with Claude Code
