diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index e969d2a21a99..20370a703ef5 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -1258,6 +1258,19 @@ def __call__(
                 if torch.backends.mps.is_available():
                     # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                     self.vae = self.vae.to(latents.dtype)
+                else:
+                    # On CUDA/CPU we still need to align the latents' device and dtype with the VAE decode parameters.
+                    # Otherwise, if the VAE (or parts of it) is in fp32 while the latents are fp16, group_norm/linear
+                    # will fail with a Half/Float dtype mismatch.
+                    try:
+                        decode_param = next(iter(self.vae.post_quant_conv.parameters()))
+                    except Exception:
+                        decode_param = next(iter(self.vae.parameters()))
+
+                    if latents.device != decode_param.device:
+                        latents = latents.to(device=decode_param.device)
+                    if latents.dtype != decode_param.dtype:
+                        latents = latents.to(dtype=decode_param.dtype)
 
             # unscale/denormalize the latents
             # denormalize with the mean and std if available and not None
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index b318a505e9db..493cdff63eec 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -179,6 +179,43 @@ def test_stable_diffusion_xl_euler(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_vae_decode_aligns_latents_dtype_when_vae_is_fp32(self):
+        # Regression test for a dtype mismatch in VAE decode on non-MPS platforms.
+        # When the VAE (or parts of it) is fp32 but the latents are fp16, the pipeline should align the
+        # latents dtype before calling VAE.decode to avoid a group_norm/linear Half/Float dtype mismatch.
+        device = "cpu"
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # Force the VAE to fp32.
+        sd_pipe.vae.to(dtype=torch.float32)
+
+        seen = {}
+
+        def decode_stub(z, return_dict=False):
+            seen["dtype"] = z.dtype
+            # Return a dummy image tensor. The shape is arbitrary as long as postprocess can handle BCHW.
+            image = torch.zeros((z.shape[0], 3, 64, 64), device=z.device, dtype=z.dtype)
+            return (image,)
+
+        sd_pipe.vae.decode = decode_stub
+
+        def callback_on_step_end(pipe, i, t, callback_kwargs):
+            latents = callback_kwargs["latents"].to(dtype=torch.float16)
+            return {"latents": latents}
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 1
+        inputs["callback_on_step_end"] = callback_on_step_end
+        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+
+        _ = sd_pipe(**inputs).images
+
+        assert seen["dtype"] == torch.float32
+
     def test_stable_diffusion_xl_euler_lcm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components(time_cond_proj_dim=256)
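
For context, here is a minimal standalone sketch (not part of the diff) of the Half/Float failure mode the pipeline change guards against, and of the alignment it performs before vae.decode(). The tiny AutoencoderKL config and the latent shape are illustrative assumptions, chosen to mirror the dummy components used in the tests:

    import torch
    from diffusers import AutoencoderKL

    # Small fp32 VAE; these config values are assumptions for illustration only.
    vae = AutoencoderKL(
        block_out_channels=(32,),
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        up_block_types=("UpDecoderBlock2D",),
        latent_channels=4,
    )

    # fp16 latents, e.g. handed back by a callback_on_step_end that downcasts them.
    latents = torch.randn(1, 4, 16, 16, dtype=torch.float16)

    # Without alignment, decoding fails inside conv/group_norm with something like:
    #   RuntimeError: Input type (torch.HalfTensor) and weight type (torch.FloatTensor) ...
    # vae.decode(latents / vae.config.scaling_factor, return_dict=False)

    # The same alignment the patch applies before vae.decode():
    decode_param = next(iter(vae.post_quant_conv.parameters()))
    latents = latents.to(device=decode_param.device, dtype=decode_param.dtype)
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]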