
Commit dcd6026

update

1 parent eae7543 · commit dcd6026

1 file changed: +63 -56 lines changed

tests/models/testing_utils/quantization.py

Lines changed: 63 additions & 56 deletions
@@ -128,9 +128,9 @@ def _test_quantization_num_parameters(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
         num_params_quantized = model_quantized.num_parameters()
 
-        assert (
-            num_params == num_params_quantized
-        ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        assert num_params == num_params_quantized, (
+            f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        )
 
     def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
         model = self._load_unquantized_model()
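This hunk and the ones below make the same mechanical change: the condition stays on the assert line and only the failure message is wrapped in parentheses. A minimal, self-contained sketch of the pattern (the helper and values are illustrative, not taken from the test suite):

# Hypothetical standalone example of the assert style adopted in this commit.
def check_parameter_counts(num_params: int, num_params_quantized: int) -> None:
    # The condition is kept on the assert line; only the f-string message is
    # parenthesized, so line-length-driven wrapping never splits the condition.
    assert num_params == num_params_quantized, (
        f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
    )

check_parameter_counts(100, 100)  # passes; a mismatch raises AssertionError with the message above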
@@ -140,9 +140,9 @@ def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_red
         mem_quantized = model_quantized.get_memory_footprint()
 
         ratio = mem / mem_quantized
-        assert (
-            ratio >= expected_memory_reduction
-        ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        assert ratio >= expected_memory_reduction, (
+            f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        )
 
     def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
@@ -237,12 +237,12 @@ def _test_quantized_layers(self, config_kwargs):
                 self._verify_if_layer_quantized(name, module, config_kwargs)
                 num_quantized_layers += 1
 
-        assert (
-            num_quantized_layers > 0
-        ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
-        assert (
-            num_quantized_layers == expected_quantized_layers
-        ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        assert num_quantized_layers > 0, (
+            f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
+        )
+        assert num_quantized_layers == expected_quantized_layers, (
+            f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        )
 
     def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
         """
@@ -266,9 +266,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no
                 if any(excluded in name for excluded in modules_to_not_convert):
                     found_excluded = True
                     # This module should NOT be quantized
-                    assert not self._is_module_quantized(
-                        module
-                    ), f"Module {name} should not be quantized but was found to be quantized"
+                    assert not self._is_module_quantized(module), (
+                        f"Module {name} should not be quantized but was found to be quantized"
+                    )
 
         assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
 
@@ -290,9 +290,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no
         mem_with_exclusion = model_with_exclusion.get_memory_footprint()
         mem_fully_quantized = model_fully_quantized.get_memory_footprint()
 
-        assert (
-            mem_with_exclusion > mem_fully_quantized
-        ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        assert mem_with_exclusion > mem_fully_quantized, (
+            f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        )
 
     def _test_quantization_device_map(self, config_kwargs):
         """
@@ -399,40 +399,40 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs):
 
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
         expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
-        assert (
-            module.weight.__class__ == expected_weight_class
-        ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        assert module.weight.__class__ == expected_weight_class, (
+            f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        )
 
-    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
+    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys()))
     def test_bnb_quantization_num_parameters(self, config_name):
         self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
+    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys()))
    def test_bnb_quantization_memory_footprint(self, config_name):
         expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
         self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected)
 
-    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
+    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys()))
     def test_bnb_quantization_inference(self, config_name):
         self._test_quantization_inference(self.BNB_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["4bit_nf4"])
+    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
     def test_bnb_quantization_dtype_assignment(self, config_name):
         self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["4bit_nf4"])
+    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
     def test_bnb_quantization_lora_inference(self, config_name):
         self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["4bit_nf4"])
+    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
     def test_bnb_quantization_serialization(self, config_name):
         self._test_quantization_serialization(self.BNB_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
+    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys()))
     def test_bnb_quantized_layers(self, config_name):
         self._test_quantized_layers(self.BNB_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
+    @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys()))
     def test_bnb_quantization_config_serialization(self, config_name):
         model = self._create_quantized_model(self.BNB_CONFIGS[config_name])
 
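The other change in this hunk (and in the quanto, torchao, and modelopt hunks below) adds explicit ids= to every pytest.mark.parametrize call. A short sketch of the effect, using a stand-in BNB_CONFIGS dict rather than the real one defined in this file:

import pytest

# Stand-in config mapping; the real BNB_CONFIGS lives in the test module.
BNB_CONFIGS = {"4bit_nf4": {"load_in_4bit": True}, "8bit": {"load_in_8bit": True}}

@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys()))
def test_example(config_name):
    # Collected as test_example[4bit_nf4] and test_example[8bit]; the explicit
    # ids tie the node names to the config keys, which keeps -k filtering and
    # test reports predictable even if the parametrize arguments change shape.
    assert config_name in BNB_CONFIGS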
@@ -469,13 +469,13 @@ def test_bnb_keep_modules_in_fp32(self):
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
                 if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
-                    assert (
-                        module.weight.dtype == torch.float32
-                    ), f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.float32, (
+                        f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    )
                 else:
-                    assert (
-                        module.weight.dtype == torch.uint8
-                    ), f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.uint8, (
+                        f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    )
 
         with torch.no_grad():
             inputs = self.get_dummy_inputs()
@@ -492,9 +492,10 @@ def test_bnb_modules_to_not_convert(self):
 
         self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude)
 
-    def test_bnb_device_map(self):
+    @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
+    def test_bnb_device_map(self, config_name):
         """Test that device_map='auto' works correctly with quantization."""
-        self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"])
+        self._test_quantization_device_map(self.BNB_CONFIGS[config_name])
 
     def test_bnb_dequantize(self):
         """Test that dequantize() works correctly."""
@@ -548,30 +549,36 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs):
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
         assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}"
 
-    @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
+    @pytest.mark.parametrize(
+        "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys())
+    )
     def test_quanto_quantization_num_parameters(self, weight_type_name):
         self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name])
 
-    @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
+    @pytest.mark.parametrize(
+        "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys())
+    )
     def test_quanto_quantization_memory_footprint(self, weight_type_name):
         expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2)
         self._test_quantization_memory_footprint(
             self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected
         )
 
-    @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
+    @pytest.mark.parametrize(
+        "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys())
+    )
     def test_quanto_quantization_inference(self, weight_type_name):
         self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
 
-    @pytest.mark.parametrize("weight_type_name", ["int8"])
+    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
     def test_quanto_quantized_layers(self, weight_type_name):
         self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name])
 
-    @pytest.mark.parametrize("weight_type_name", ["int8"])
+    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
     def test_quanto_quantization_lora_inference(self, weight_type_name):
         self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
 
-    @pytest.mark.parametrize("weight_type_name", ["int8"])
+    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
     def test_quanto_quantization_serialization(self, weight_type_name):
         self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name])
 
@@ -636,30 +643,30 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs):
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
         assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
 
-    @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
+    @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys()))
     def test_torchao_quantization_num_parameters(self, quant_type):
         self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type])
 
-    @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
+    @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys()))
     def test_torchao_quantization_memory_footprint(self, quant_type):
         expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2)
         self._test_quantization_memory_footprint(
             self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected
         )
 
-    @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
+    @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys()))
     def test_torchao_quantization_inference(self, quant_type):
         self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type])
 
-    @pytest.mark.parametrize("quant_type", ["int8wo"])
+    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
     def test_torchao_quantized_layers(self, quant_type):
         self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type])
 
-    @pytest.mark.parametrize("quant_type", ["int8wo"])
+    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
     def test_torchao_quantization_lora_inference(self, quant_type):
         self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type])
 
-    @pytest.mark.parametrize("quant_type", ["int8wo"])
+    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
     def test_torchao_quantization_serialization(self, quant_type):
         self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type])
 
@@ -801,34 +808,34 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs):
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
         assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)"
 
-    @pytest.mark.parametrize("config_name", ["fp8"])
+    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
     def test_modelopt_quantization_num_parameters(self, config_name):
         self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
+    @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()), ids=list(MODELOPT_CONFIGS.keys()))
     def test_modelopt_quantization_memory_footprint(self, config_name):
         expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
         self._test_quantization_memory_footprint(
             self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected
         )
 
-    @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
+    @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()), ids=list(MODELOPT_CONFIGS.keys()))
     def test_modelopt_quantization_inference(self, config_name):
         self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["fp8"])
+    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
     def test_modelopt_quantization_dtype_assignment(self, config_name):
         self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["fp8"])
+    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
     def test_modelopt_quantization_lora_inference(self, config_name):
         self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["fp8"])
+    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
     def test_modelopt_quantization_serialization(self, config_name):
         self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["fp8"])
+    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
     def test_modelopt_quantized_layers(self, config_name):
         self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name])
 