diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index b46f2dd70..d2c19b8e1 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -56,9 +56,6 @@
     set_expert_quantizer_amax,
 )
 from .model_config import (
-    KV_CACHE_FP8,
-    KV_CACHE_NVFP4,
-    KV_CACHE_NVFP4_AFFINE,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
     QUANTIZATION_FP8_PC_PT,
@@ -647,19 +644,6 @@ def _export_transformers_checkpoint(
 
     quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)
 
-    kv_cache_max_bound = 0
-    kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"]
-
-    cache_bound_mapping = {
-        KV_CACHE_NVFP4: 6 * 448,
-        KV_CACHE_NVFP4_AFFINE: 6 * 448,
-        KV_CACHE_FP8: 448,
-    }
-
-    # Only update kv_cache_max_bound if a quantization is applied.
-    if kv_cache_format != QUANTIZATION_NONE:
-        kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format)
-
     # Process all quantized modules and export weights
     _process_quantized_modules(model, dtype, is_modelopt_qlora)
 
@@ -669,6 +653,9 @@
     else:
         quantized_state_dict = model.state_dict()
 
+    # We define kv cache scale as amax / 448 for both FP8 and NVFP4 KV cache quantization.
+    kv_cache_max_bound = 448
+    kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"]
     quantized_state_dict = postprocess_state_dict(
         quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
     )
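
For context, a minimal sketch of what the new constant implies (the helper below is hypothetical and not part of the patch): with kv_cache_max_bound fixed at 448, the exported KV cache scale for both FP8 and NVFP4 reduces to the calibrated amax divided by 448.

# Minimal sketch, not from the patch: the fixed bound of 448 means the exported
# KV cache scale is simply amax / 448 for FP8 and NVFP4 KV cache quantization alike.
KV_CACHE_MAX_BOUND = 448  # constant introduced by this change

def kv_cache_scale_from_amax(amax: float) -> float:
    """Hypothetical helper: convert a calibrated amax into an exported KV cache scale."""
    return amax / KV_CACHE_MAX_BOUND

# e.g. a calibrated amax of 896.0 maps to a scale of 2.0
assert kv_cache_scale_from_amax(896.0) == 2.0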