diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 785c4a16ea4b..dca426198fdd 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -21,11 +21,8 @@
 from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
 from diffusers.utils.import_utils import (
     is_bitsandbytes_available,
-    is_gguf_available,
     is_nvidia_modelopt_available,
     is_optimum_quanto_available,
-    is_torchao_available,
-    is_torchao_version,
 )
 
 from ...testing_utils import (
@@ -59,13 +56,6 @@ if is_optimum_quanto_available():
     from optimum.quanto import QLinear
 
 
-if is_gguf_available():
-    pass
-
-if is_torchao_available():
-    if is_torchao_version(">=", "0.9.0"):
-        pass
-
 
 class LoRALayer(torch.nn.Module):
     """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
@@ -132,14 +122,14 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs):
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
         raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
 
-    def _is_module_quantized(self, module):
+    def _is_module_quantized(self, module, quant_config_kwargs=None):
         """
         Check if a module is quantized. Returns True if quantized, False otherwise.
         Default implementation tries _verify_if_layer_quantized and catches exceptions.
         Subclasses can override for more efficient checking.
         """
         try:
-            self._verify_if_layer_quantized("", module, {})
+            self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
             return True
         except (AssertionError, AttributeError):
             return False
@@ -273,7 +263,9 @@ def _test_quantized_layers(self, config_kwargs):
             f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
         )
 
-    def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
+    def _test_quantization_modules_to_not_convert(
+        self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
+    ):
         """
         Test that modules specified in modules_to_not_convert are not quantized.
@@ -283,7 +275,7 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no
         """
         # Create config with modules_to_not_convert
         config_kwargs_with_exclusion = config_kwargs.copy()
-        config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
+        config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert
 
         model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
 
@@ -295,7 +287,7 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no
                 if any(excluded in name for excluded in modules_to_not_convert):
                     found_excluded = True
                     # This module should NOT be quantized
-                    assert not self._is_module_quantized(module), (
+                    assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
                         f"Module {name} should not be quantized but was found to be quantized"
                     )
 
@@ -307,7 +299,7 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no
             if isinstance(module, torch.nn.Linear):
                 # Check if this module is NOT in the exclusion list
                 if not any(excluded in name for excluded in modules_to_not_convert):
-                    if self._is_module_quantized(module):
+                    if self._is_module_quantized(module, config_kwargs_with_exclusion):
                         found_quantized = True
                         break
 
@@ -612,7 +604,7 @@ def test_bnb_modules_to_not_convert(self):
             pytest.skip("modules_to_not_convert_for_test not defined for this model")
 
         self._test_quantization_modules_to_not_convert(
-            BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
+            BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
         )
 
     @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
@@ -811,7 +803,14 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs):
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
 
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
+        from torchao.dtypes import AffineQuantizedTensor
+        from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
         assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
+        # Check if the weight is actually quantized
+        weight = module.weight
+        is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
+        assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"
 
         # int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
@@ -907,9 +906,39 @@ def test_torchao_modules_to_not_convert(self):
         if modules_to_exclude is None:
             pytest.skip("modules_to_not_convert_for_test not defined for this model")
 
-        self._test_quantization_modules_to_not_convert(
-            TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
-        )
+        # Custom implementation for torchao that skips memory footprint check
+        # because get_memory_footprint() doesn't accurately reflect torchao quantization
+        config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
+        config_kwargs_with_exclusion = config_kwargs.copy()
+        config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude
+
+        model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
+
+        # Find a module that should NOT be quantized
+        found_excluded = False
+        for name, module in model_with_exclusion.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                # Check if this module is in the exclusion list
+                if any(excluded in name for excluded in modules_to_exclude):
+                    found_excluded = True
+                    # This module should NOT be quantized
+                    assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
+                        f"Module {name} should not be quantized but was found to be quantized"
+                    )
+
+        assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"
+
+        # Find a module that SHOULD be quantized (not in exclusion list)
+        found_quantized = False
+        for name, module in model_with_exclusion.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                # Check if this module is NOT in the exclusion list
+                if not any(excluded in name for excluded in modules_to_exclude):
+                    if self._is_module_quantized(module, config_kwargs_with_exclusion):
+                        found_quantized = True
+                        break
+
+        assert found_quantized, "No quantized layers found outside of excluded modules"
 
     def test_torchao_device_map(self):
         """Test that device_map='auto' works correctly with quantization."""
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 2d39dadfcad1..4008b66e3bcb 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -318,6 +318,10 @@ def pretrained_model_name_or_path(self):
 class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
     """BitsAndBytes quantization tests for Flux Transformer."""
 
+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+
 
 class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
     """Quanto quantization tests for Flux Transformer."""
@@ -330,10 +334,18 @@ def pretrained_model_name_or_path(self):
     def pretrained_model_kwargs(self):
         return {}
 
+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+
 
 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
     """TorchAO quantization tests for Flux Transformer."""
 
+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+
 
 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     @property
@@ -402,6 +414,10 @@ def get_dummy_inputs(self):
 class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
     """ModelOpt quantization tests for Flux Transformer."""
 
+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+
 
 class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
     """ModelOpt + compile tests for Flux Transformer."""
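
Note on the new to_not_convert_key argument: the backends spell the exclusion list differently, so the shared helper can no longer hard-code "modules_to_not_convert". A minimal sketch of the config-level difference, assuming the current diffusers config signatures; the excluded module name "norm_out.linear" is taken from the Flux tester classes above, not a required value.

from diffusers import BitsAndBytesConfig, TorchAoConfig

# bitsandbytes exposes the exclusion list as `llm_int8_skip_modules`,
# which is why test_bnb_modules_to_not_convert passes "llm_int8_skip_modules"
# as to_not_convert_key.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_skip_modules=["norm_out.linear"],
)

# torchao keeps the helper's default key, `modules_to_not_convert`.
torchao_config = TorchAoConfig("int8wo", modules_to_not_convert=["norm_out.linear"])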