From ddc8f2e57567ee7eb1a7e3723ba9779689e7b9bd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 May 2023 10:19:28 +0200 Subject: [PATCH 1/7] Fix more torch compile breaks --- .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 9 +++++---- .../pipelines/deepfloyd_if/pipeline_if_img2img.py | 7 ++++--- .../deepfloyd_if/pipeline_if_img2img_superresolution.py | 7 ++++--- .../pipelines/deepfloyd_if/pipeline_if_inpainting.py | 7 ++++--- .../pipeline_if_inpainting_superresolution.py | 7 ++++--- .../deepfloyd_if/pipeline_if_superresolution.py | 7 ++++--- .../pipeline_stable_diffusion_controlnet.py | 5 +++-- .../pipeline_stable_diffusion_depth2img.py | 6 ++++-- .../pipeline_stable_diffusion_img2img.py | 9 +++++---- .../pipeline_stable_diffusion_inpaint.py | 6 ++++-- .../pipeline_stable_diffusion_inpaint_legacy.py | 6 ++++-- .../pipeline_stable_diffusion_instruct_pix2pix.py | 6 ++++-- .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 8 +++++--- .../stable_diffusion/pipeline_stable_unclip_img2img.py | 5 +++-- 14 files changed, 57 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 5df9bab3ae41..cabed8f017ce 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -457,7 +457,7 @@ def decode_latents(self, latents): FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -728,7 +728,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -736,7 +737,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -745,7 +746,7 @@ def __call__( callback(i, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py index fac4adeea463..231ee02b1bb8 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -918,7 +918,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -930,8 +931,8 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py index eed1bb43e5d8..770676c15984 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -1036,7 +1036,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -1048,8 +1049,8 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py index d3651f5169c1..6986387ca995 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -1033,7 +1033,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -1047,8 +1048,8 @@ def __call__( prev_intermediate_images = intermediate_images intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py index 5ea6a47082ae..2b42d3992ed8 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -1143,7 +1143,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -1157,8 +1158,8 @@ def __call__( prev_intermediate_images = intermediate_images intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py index a62a51b0972f..4729cec3e4d7 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -886,7 +886,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -898,8 +899,8 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index e36b0bcdf759..b5f38fdd3efe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -1006,7 +1006,8 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -1014,7 +1015,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 378eb927ca52..16f96bbc2fd5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -677,7 +677,9 @@ def __call__( latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[ + 0 + ] # perform guidance if do_classifier_free_guidance: @@ -685,7 +687,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 5e9a0f9e350b..2dfa730549ab 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -462,7 +462,7 @@ def decode_latents(self, latents): FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -734,7 +734,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -742,7 +743,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -751,7 +752,7 @@ def __call__( callback(i, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor).sample + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cac7465298cc..859a34677317 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -878,7 +878,9 @@ def __call__( latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[ + 0 + ] # perform guidance if do_classifier_free_guidance: @@ -886,7 +888,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 6d93fba2425e..990c0e838f35 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -690,7 +690,9 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[ + 0 + ] # perform guidance if do_classifier_free_guidance: @@ -698,7 +700,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # masking if add_predicted_noise: init_latents_proper = self.scheduler.add_noise( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 225e3719b98f..b9dd3aa24b11 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -346,7 +346,9 @@ def __call__( scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False + )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -376,7 +378,7 @@ def __call__( noise_pred = (noise_pred - latents) / (-sigma) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 3e34dcb98132..51ba24c65873 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -830,7 +830,8 @@ def __call__( timestep=t, sample=prior_latents, **prior_extra_step_kwargs, - ).prev_sample + return_dict=False, + )[0] if callback is not None and i % callback_steps == 0: callback(i, t, prior_latents) @@ -903,7 +904,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=image_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -911,7 +913,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 9d6a6c8332fb..fce82a5bb61f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -799,7 +799,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=image_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -807,7 +808,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback is not None and i % callback_steps == 0: callback(i, t, latents) From d67794befe5c8928acb7e739cfd5275377499360 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 May 2023 10:31:44 +0200 Subject: [PATCH 2/7] add tests --- .../test_stable_diffusion_controlnet.py | 43 +++++++++++++++++++ .../test_stable_diffusion_img2img.py | 23 ++++++++++ .../test_stable_diffusion_inpaint.py | 26 +++++++++++ 3 files changed, 92 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py index 70b3652fce77..bf3032ee42a1 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -19,6 +19,7 @@ import numpy as np import torch +from packaging import version from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( @@ -585,6 +586,48 @@ def test_canny_guess_mode(self): expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_compile(self): + if version.parse(torch.__version__) >= version.parse("2.0"): + print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") + return + + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.to("cuda") + pipe.set_progress_bar_config(disable=None) + + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + pipe.controlnet.to(memory_format=torch.channels_last) + pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + output = pipe( + prompt, + image, + generator=generator, + output_type="np", + num_inference_steps=3, + guidance_scale=3.0, + guess_mode=True, + ) + + image = output.images[0] + assert image.shape == (768, 512, 3) + + image_slice = image[-3:, -3:, -1] + expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 123f5464dfaa..62a259507a43 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from packaging import version import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -460,6 +461,28 @@ def test_img2img_safety_checker_works(self): assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}" assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros + def test_img2img_compile(self): + if version.parse(torch.__version__) >= version.parse("2.0"): + print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") + return + + pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + inputs = self.get_inputs(torch_device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 768, 3) + expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781]) + + assert np.abs(expected_slice - image_slice).max() < 1e-3 + @nightly @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 290d9b0a9134..ff28ba3da8dc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from packaging import version import torch from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -274,6 +275,31 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self): # make sure that less than 2.2 GB is allocated assert mem_bytes < 2.2 * 10**9 + def test_inpaint_compile(self): + if version.parse(torch.__version__) >= version.parse("2.0"): + print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") + return + + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + inputs = self.get_inputs(torch_device) + image = pipe(**inputs).images + image_slice = image[0, 253:256, 253:256, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272]) + + assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 1e-3 + @nightly @require_torch_gpu From 2f83bed4e8e6f24bb4a02ecb0bee45ad3adbda7f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 May 2023 08:36:54 +0000 Subject: [PATCH 3/7] Fix all --- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 2 +- .../stable_diffusion/test_stable_diffusion_controlnet.py | 2 +- .../pipelines/stable_diffusion/test_stable_diffusion_img2img.py | 2 +- .../pipelines/stable_diffusion/test_stable_diffusion_inpaint.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index e1334e1ddd3b..a1005a5e9f55 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -923,7 +923,7 @@ def test_download_ckpt_diff_format_is_same(self): assert np.max(np.abs(image - image_ckpt)) < 1e-4 def test_stable_diffusion_compile(self): - if version.parse(torch.__version__) >= version.parse("2.0"): + if version.parse(torch.__version__) < version.parse("2.0"): print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") return diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py index bf3032ee42a1..732142ce4c2d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -587,7 +587,7 @@ def test_canny_guess_mode(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_compile(self): - if version.parse(torch.__version__) >= version.parse("2.0"): + if version.parse(torch.__version__) < version.parse("2.0"): print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") return diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 62a259507a43..0532139dedfc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -462,7 +462,7 @@ def test_img2img_safety_checker_works(self): assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros def test_img2img_compile(self): - if version.parse(torch.__version__) >= version.parse("2.0"): + if version.parse(torch.__version__) < version.parse("2.0"): print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") return diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index ff28ba3da8dc..a73a69ae3bd5 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -276,7 +276,7 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self): assert mem_bytes < 2.2 * 10**9 def test_inpaint_compile(self): - if version.parse(torch.__version__) >= version.parse("2.0"): + if version.parse(torch.__version__) < version.parse("2.0"): print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") return From b7338e8512174bf7ded2a5c8eae51dadb97eacda Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 May 2023 09:44:32 +0000 Subject: [PATCH 4/7] fix controlnet --- src/diffusers/models/controlnet.py | 13 +++--- .../pipeline_stable_diffusion_controlnet.py | 43 +++++++++++++++--- .../pipeline_stable_diffusion_upscale.py | 10 +++-- .../stable_diffusion/test_stable_diffusion.py | 44 +++++++++---------- .../test_stable_diffusion_controlnet.py | 22 ++++------ .../test_stable_diffusion_img2img.py | 2 +- .../test_stable_diffusion_inpaint.py | 2 +- 7 files changed, 83 insertions(+), 53 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 3ffbb04eb222..532bb2c3b8f7 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -498,7 +498,7 @@ def forward( # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) @@ -517,7 +517,7 @@ def forward( controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample += controlnet_cond + sample = sample + controlnet_cond # 3. down down_block_res_samples = (sample,) @@ -559,13 +559,14 @@ def forward( # 6. scaling if guess_mode: - scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0 - scales *= conditioning_scale + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] - mid_block_res_sample *= scales[-1] # last one + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample *= conditioning_scale + mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index b5f38fdd3efe..5e8e68823b34 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -20,6 +20,7 @@ import numpy as np import PIL.Image import torch +import torch.nn.functional as F from torch import nn from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer @@ -579,9 +580,20 @@ def check_inputs( ) # Check `image` - if isinstance(self.controlnet, ControlNetModel): + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): self.check_image(image, prompt, prompt_embeds) - elif isinstance(self.controlnet, MultiControlNetModel): + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") @@ -600,10 +612,18 @@ def check_inputs( assert False # Check `controlnet_conditioning_scale` - if isinstance(self.controlnet, ControlNetModel): + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif isinstance(self.controlnet, MultiControlNetModel): + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError("A single batch of multiple conditionings are supported at the moment.") @@ -910,7 +930,14 @@ def __call__( ) # 4. Prepare image - if isinstance(self.controlnet, ControlNetModel): + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): image = self.prepare_image( image=image, width=width, @@ -922,7 +949,11 @@ def __call__( do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) - elif isinstance(self.controlnet, MultiControlNetModel): + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): images = [] for image_ in image: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index a8c29f32e9e5..da1575289c8e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -678,8 +678,12 @@ def __call__( # predict the noise residual noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level - ).sample + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -687,7 +691,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index a1005a5e9f55..4583cc42e6f1 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -866,6 +866,28 @@ def test_stable_diffusion_textual_inversion(self): max_diff = np.abs(expected_image - image).max() assert max_diff < 5e-2 + def test_stable_diffusion_compile(self): + if version.parse(torch.__version__) < version.parse("2.0"): + print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") + return + + sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) + sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + + sd_pipe.unet.to(memory_format=torch.channels_last) + sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True) + + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239]) + assert np.abs(image_slice - expected_slice).max() < 5e-3 + @slow @require_torch_gpu @@ -922,28 +944,6 @@ def test_download_ckpt_diff_format_is_same(self): assert np.max(np.abs(image - image_ckpt)) < 1e-4 - def test_stable_diffusion_compile(self): - if version.parse(torch.__version__) < version.parse("2.0"): - print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") - return - - sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) - sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config) - sd_pipe = sd_pipe.to(torch_device) - - sd_pipe.unet.to(memory_format=torch.channels_last) - sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True) - - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_inputs(torch_device) - image = sd_pipe(**inputs).images - image_slice = image[0, -3:, -3:, -1].flatten() - - assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 - @nightly @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py index 732142ce4c2d..279df4a32b29 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -606,27 +606,21 @@ def test_stable_diffusion_compile(self): pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "" + prompt = "bird" image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" ) - output = pipe( - prompt, - image, - generator=generator, - output_type="np", - num_inference_steps=3, - guidance_scale=3.0, - guess_mode=True, - ) - + output = pipe(prompt, image, generator=generator, output_type="np") image = output.images[0] + assert image.shape == (768, 512, 3) - image_slice = image[-3:, -3:, -1] - expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy" + ) + + assert np.abs(expected_image - image).max() < 1e-1 @slow diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 0532139dedfc..2f63371c1a0d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -18,8 +18,8 @@ import unittest import numpy as np -from packaging import version import torch +from packaging import version from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index a73a69ae3bd5..20977c346ecc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -18,8 +18,8 @@ import unittest import numpy as np -from packaging import version import torch +from packaging import version from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer From dc8318f51dcd58b80d05fada46c55029466d1f42 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 May 2023 11:16:18 +0000 Subject: [PATCH 5/7] fix more --- src/diffusers/models/controlnet.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 532bb2c3b8f7..7b36d2eed96a 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -551,7 +551,7 @@ def forward( for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 83169455fa3e..2a4c9fd72c1b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -740,7 +740,7 @@ def forward( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index e9e31d67905b..f0a210339c46 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -843,7 +843,7 @@ def forward( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples From bbdaf9b2e9e3c0b859015e6b2c2d4366369b9582 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 May 2023 11:19:36 +0000 Subject: [PATCH 6/7] Add Horace He as co-author. > > Co-authored-by: Horace He From eef3e598892cad83f2cdcb689ea9d254e4bfaa52 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 2 May 2023 11:21:11 +0000 Subject: [PATCH 7/7] Add Horace He as co-author. Co-authored-by: Horace He