diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 27475dc5ef8b..83ddd51c02f7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -50,49 +50,59 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch - >>> import cv2 - >>> from PIL import Image + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + - >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> def make_inpaint_condition(image, image_mask): + ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 - >>> init_image = load_image(img_url).resize((512, 512)) - >>> mask_image = load_image(mask_url).resize((512, 512)) + ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + ... image[image_mask > 0.5] = -1.0 # set as masked pixel + ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + ... image = torch.from_numpy(image) + ... return image - >>> image = np.array(init_image) - >>> # get canny image - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) + >>> control_image = make_inpaint_condition(init_image, mask_image) - >>> # load control net and stable diffusion inpainting - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( - ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # generate image - >>> generator = torch.manual_seed(0) >>> image = pipe( - ... "spiderman", - ... num_inference_steps=30, + ... "a beautiful man", + ... num_inference_steps=20, ... generator=generator, + ... eta=1.0, ... image=init_image, ... mask_image=mask_image, - ... control_image=canny_image, + ... control_image=control_image, ... ).images[0] ``` """ @@ -226,6 +236,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + + This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as + [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) + as well as default text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). + Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on + those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. @@ -597,6 +618,16 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + def check_inputs( self, prompt, @@ -812,6 +843,8 @@ def prepare_latents( image=None, timestep=None, is_strength_max=True, + return_noise=False, + return_image_latents=False, ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -826,32 +859,28 @@ def prepare_latents( "However, either the image or the noise timestep has not been provided." ) + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - if is_strength_max: - # if strength is 100% then simply initialise the latents to noise - latents = noise - else: - # otherwise initialise latents as init image + noise - image = image.to(device=device, dtype=dtype) - if isinstance(generator, list): - image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) - - image_latents = self.vae.config.scaling_factor * image_latents - - latents = self.scheduler.add_noise(image_latents, noise, timestep) + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - return latents + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs def _default_height_width(self, height, width, image): # NOTE: It is possible that a list of images have different @@ -891,17 +920,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -930,6 +949,21 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + # override DiffusionPipeline def save_pretrained( self, @@ -954,6 +988,7 @@ def __call__( ] = None, height: Optional[int] = None, width: Optional[int] = None, + strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -990,6 +1025,13 @@ def __call__( The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -1145,13 +1187,25 @@ def __call__( assert False # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -1160,10 +1214,19 @@ def __call__( device, generator, latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + # 7. Prepare mask latent variables - mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, @@ -1213,7 +1276,9 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + noise_pred = self.unet( latent_model_input, t, @@ -1232,6 +1297,15 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t])) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 24b05f36f913..c8f3e8a9ee11 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -328,17 +328,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -367,6 +357,21 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): dtype = next(self.image_encoder.parameters()).dtype diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index f09db016d956..5dbac9295800 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -155,7 +155,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + Pipeline for text-guided image inpainting using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) @@ -167,6 +167,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + + It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such + as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default + text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) are also compatible with + this pipeline, but might be less performant. + + + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. @@ -266,14 +276,10 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 if unet.config.in_channels != 9: - logger.warning( - f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," - f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`," - " 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify" - " this behavior, please check whether you have loaded the right checkpoint." - ) + logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( vae=vae, @@ -620,6 +626,8 @@ def prepare_latents( image=None, timestep=None, is_strength_max=True, + return_noise=False, + return_image_latents=False, ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -634,32 +642,42 @@ def prepare_latents( "However, either the image or the noise timestep has not been provided." ) + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - if is_strength_max: - # if strength is 100% then simply initialise the latents to noise - latents = noise - else: - # otherwise initialise latents as init image + noise - image = image.to(device=device, dtype=dtype) - if isinstance(generator, list): - image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) - - image_latents = self.vae.config.scaling_factor * image_latents - - latents = self.scheduler.add_noise(image_latents, noise, timestep) + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - return latents + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents def prepare_mask_latents( self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance @@ -673,17 +691,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -916,7 +924,10 @@ def __call__( # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -928,8 +939,15 @@ def __call__( image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, @@ -942,17 +960,25 @@ def __call__( generator, do_classifier_free_guidance, ) + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image = self._encode_vae_image(init_image, generator=generator) # 8. Check that sizes of mask, masked image and latents match - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." ) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -967,7 +993,9 @@ def __call__( # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual noise_pred = self.unet( @@ -986,6 +1014,15 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t])) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 5a2329a5c51f..c549d869e685 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -123,7 +123,6 @@ class StableDiffusionInpaintPipelineLegacy( """ _optional_components = ["feature_extractor"] - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( self, vae: AutoencoderKL, @@ -137,6 +136,13 @@ def __init__( ): super().__init__() + deprecation_message = ( + f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality" + "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533" + "for more information." + ) + deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False) + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 155286630c04..f8cc881e8650 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -163,6 +163,78 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) +class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests): + pipeline_class = StableDiffusionControlNetInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionControlNetInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS @@ -376,3 +448,60 @@ def test_canny(self): ) assert np.abs(expected_image - image).max() < 9e-2 + + def test_inpaint(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint") + + pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(33) + + init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ) + init_image = init_image.resize((512, 512)) + + mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ) + mask_image = mask_image.resize((512, 512)) + + prompt = "a handsome man with ray-ban sunglasses" + + def make_inpaint_condition(image, image_mask): + image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + image[image_mask > 0.5] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image + + control_image = make_inpaint_condition(init_image, mask_image) + + output = pipe( + prompt, + image=init_image, + mask_image=mask_image, + control_image=control_image, + guidance_scale=9.0, + eta=1.0, + generator=generator, + num_inference_steps=20, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/boy_ray_ban.npy" + ) + + assert np.abs(expected_image - image).max() < 9e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 44de277ead07..a9337417289a 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -193,6 +193,82 @@ def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) +class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests): + pipeline_class = StableDiffusionInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def test_stable_diffusion_inpaint(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4925, 0.4967, 0.4100, 0.5234, 0.5322, 0.4532, 0.5805, 0.5877, 0.4151]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skip("skipped here because area stays unchanged due to mask") + def test_stable_diffusion_inpaint_lora(self): + ... + + @slow @require_torch_gpu class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase): @@ -378,6 +454,22 @@ def test_stable_diffusion_inpaint_strength_test(self): expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943]) assert np.abs(expected_slice - image_slice).max() < 3e-3 + def test_stable_diffusion_simple_inpaint_ddim(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + image = pipe(**inputs).images + + image_slice = image[0, 253:256, 253:256, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.5157, 0.6858, 0.6873, 0.4619, 0.6416, 0.6898, 0.3702, 0.5960, 0.6935]) + + assert np.abs(expected_slice - image_slice).max() < 6e-4 + @nightly @require_torch_gpu