diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 27475dc5ef8b..83ddd51c02f7 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -50,49 +50,59 @@
EXAMPLE_DOC_STRING = """
Examples:
```py
- >>> # !pip install opencv-python transformers accelerate
- >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
+ >>> # !pip install transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch
- >>> import cv2
- >>> from PIL import Image
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+ ... )
+ >>> init_image = init_image.resize((512, 512))
+
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
+
+ >>> mask_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
+ ... )
+ >>> mask_image = mask_image.resize((512, 512))
+
- >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
- >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> def make_inpaint_condition(image, image_mask):
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
- >>> init_image = load_image(img_url).resize((512, 512))
- >>> mask_image = load_image(mask_url).resize((512, 512))
+ ... assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
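+ ... # -1.0 lies outside the valid [0, 1] pixel range, letting the inpaint controlnet tell masked pixels apart from real image content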
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ ... image = torch.from_numpy(image)
+ ... return image
- >>> image = np.array(init_image)
- >>> # get canny image
- >>> image = cv2.Canny(image, 100, 200)
- >>> image = image[:, :, None]
- >>> image = np.concatenate([image, image, image], axis=2)
- >>> canny_image = Image.fromarray(image)
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
- >>> # load control net and stable diffusion inpainting
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
+ ... )
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> pipe.enable_model_cpu_offload()
>>> # generate image
- >>> generator = torch.manual_seed(0)
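+ >>> # eta is only consumed by DDIM-like schedulers; eta=1.0 makes the sampling fully stochastic (DDPM-like)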
>>> image = pipe(
- ... "spiderman",
- ... num_inference_steps=30,
+ ... "a beautiful man",
+ ... num_inference_steps=20,
... generator=generator,
+ ... eta=1.0,
... image=init_image,
... mask_image=mask_image,
- ... control_image=canny_image,
+ ... control_image=control_image,
... ).images[0]
```
"""
@@ -226,6 +236,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
+ This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as
+ [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), and with
+ default text-to-image stable diffusion checkpoints, such as
+ [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Default text-to-image
+ checkpoints might be preferable for controlnets that have been fine-tuned on them, such as
+ [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
+
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -597,6 +618,16 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
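+ # e.g. strength=0.5 with num_inference_steps=50 yields init_timestep=25, so denoising starts halfway through the schedule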
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
def check_inputs(
self,
prompt,
@@ -812,6 +843,8 @@ def prepare_latents(
image=None,
timestep=None,
is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
@@ -826,32 +859,28 @@ def prepare_latents(
"However, either the image or the noise timestep has not been provided."
)
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
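+ # encode the init image up front: its latents seed the partially noised start when strength < 1 and are returned when requested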
+
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- if is_strength_max:
- # if strength is 100% then simply initialise the latents to noise
- latents = noise
- else:
- # otherwise initialise latents as init image + noise
- image = image.to(device=device, dtype=dtype)
- if isinstance(generator, list):
- image_latents = [
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
- for i in range(batch_size)
- ]
- else:
- image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
-
- image_latents = self.vae.config.scaling_factor * image_latents
-
- latents = self.scheduler.add_noise(image_latents, noise, timestep)
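+ # if strength is 100%, start from pure noise; otherwise noise the image latents to the given timestep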
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
- return latents
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
def _default_height_width(self, height, width, image):
# NOTE: It is possible that a list of images have different
@@ -891,17 +920,7 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
-
- # encode the mask image into latents space so we can concatenate it to the latents
- if isinstance(generator, list):
- masked_image_latents = [
- self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
- for i in range(batch_size)
- ]
- masked_image_latents = torch.cat(masked_image_latents, dim=0)
- else:
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
- masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -930,6 +949,21 @@ def prepare_mask_latents(
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
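+ # when a list of generators is passed, each batch element is encoded with its own generator to keep per-image reproducibility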
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
# override DiffusionPipeline
def save_pretrained(
self,
@@ -954,6 +988,7 @@ def __call__(
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
+ strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -990,6 +1025,13 @@ def __call__(
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 1.0):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` is used as a starting point, and more noise is added the higher the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, the added noise is maximal and the denoising process runs for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -1145,13 +1187,25 @@ def __call__(
assert False
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, mask_image, height, width, return_image=True
+ )
+
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # at which timestep to set the initial noise (e.g. the 50% mark of the schedule if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # check whether strength is set to 1.0; if so, the latents are initialised with pure noise
+ is_strength_max = strength == 1.0
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
- latents = self.prepare_latents(
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
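+ # a plain text-to-image unet (4 input channels) takes no mask inputs, so the image latents are needed to restore the unmasked region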
+ latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@@ -1160,10 +1214,19 @@ def __call__(
device,
generator,
latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
)
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
# 7. Prepare mask latent variables
- mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
@@ -1213,7 +1276,9 @@ def __call__(
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
# predict the noise residual
- latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
noise_pred = self.unet(
latent_model_input,
t,
@@ -1232,6 +1297,15 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
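+ # with a 4-channel unet, emulate inpainting: outside the mask, overwrite the latents with the init latents noised to the current timestep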
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents[:1]
+ init_mask = mask[:1]
+
+ if i < len(timesteps) - 1:
+ init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 24b05f36f913..c8f3e8a9ee11 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -328,17 +328,7 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
-
- # encode the mask image into latents space so we can concatenate it to the latents
- if isinstance(generator, list):
- masked_image_latents = [
- self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
- for i in range(batch_size)
- ]
- masked_image_latents = torch.cat(masked_image_latents, dim=0)
- else:
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
- masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -367,6 +357,21 @@ def prepare_mask_latents(
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
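+ # when a list of generators is passed, each batch element is encoded with its own generator to keep per-image reproducibility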
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
dtype = next(self.image_encoder.parameters()).dtype
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index f09db016d956..5dbac9295800 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -155,7 +155,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
- Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+ Pipeline for text-guided image inpainting using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -167,6 +167,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such
+ as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default
+ text-to-image stable diffusion checkpoints, such as
+ [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), are also compatible with
+ this pipeline, but might be less performant.
+
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -266,14 +276,10 @@ def __init__(
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
+
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 9:
- logger.warning(
- f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
- f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`,"
- " 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify"
- " this behavior, please check whether you have loaded the right checkpoint."
- )
+ logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels instead of the default 9; the mask and masked image will not be concatenated to the unet input.")
self.register_modules(
vae=vae,
@@ -620,6 +626,8 @@ def prepare_latents(
image=None,
timestep=None,
is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
@@ -634,32 +642,42 @@ def prepare_latents(
"However, either the image or the noise timestep has not been provided."
)
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
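+ # encode the init image up front: its latents seed the partially noised start when strength < 1 and are returned when requested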
+
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- if is_strength_max:
- # if strength is 100% then simply initialise the latents to noise
- latents = noise
- else:
- # otherwise initialise latents as init image + noise
- image = image.to(device=device, dtype=dtype)
- if isinstance(generator, list):
- image_latents = [
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
- for i in range(batch_size)
- ]
- else:
- image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
-
- image_latents = self.vae.config.scaling_factor * image_latents
-
- latents = self.scheduler.add_noise(image_latents, noise, timestep)
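+ # if strength is 100%, start from pure noise; otherwise noise the image latents to the given timestep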
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
- return latents
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
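+ # when a list of generators is passed, each batch element is encoded with its own generator to keep per-image reproducibility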
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
@@ -673,17 +691,7 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
-
- # encode the mask image into latents space so we can concatenate it to the latents
- if isinstance(generator, list):
- masked_image_latents = [
- self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
- for i in range(batch_size)
- ]
- masked_image_latents = torch.cat(masked_image_latents, dim=0)
- else:
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
- masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -916,7 +924,10 @@ def __call__(
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
- latents = self.prepare_latents(
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
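+ # a plain text-to-image unet (4 input channels) takes no mask inputs, so the image latents are needed to restore the unmasked region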
+
+ latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@@ -928,8 +939,15 @@ def __call__(
image=init_image,
timestep=latent_timestep,
is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
)
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
@@ -942,17 +960,25 @@ def __call__(
generator,
do_classifier_free_guidance,
)
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+ init_image = self._encode_vae_image(init_image, generator=generator)
# 8. Check that sizes of mask, masked image and latents match
- num_channels_mask = mask.shape[1]
- num_channels_masked_image = masked_image_latents.shape[1]
- if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
raise ValueError(
- f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
- f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
- f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
- " `pipeline.unet` or your `mask_image` or `image` input."
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
)
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -967,7 +993,9 @@ def __call__(
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# predict the noise residual
noise_pred = self.unet(
@@ -986,6 +1014,15 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
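+ # with a 4-channel unet, emulate inpainting: outside the mask, overwrite the latents with the init latents noised to the current timestep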
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents[:1]
+ init_mask = mask[:1]
+
+ if i < len(timesteps) - 1:
+ init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 5a2329a5c51f..c549d869e685 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -123,7 +123,6 @@ class StableDiffusionInpaintPipelineLegacy(
"""
_optional_components = ["feature_extractor"]
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
self,
vae: AutoencoderKL,
@@ -137,6 +136,13 @@ def __init__(
):
super().__init__()
+ deprecation_message = (
+ f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality"
+ "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533"
+ "for more information."
+ )
+ deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index 155286630c04..f8cc881e8650 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -163,6 +163,78 @@ def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests):
+ pipeline_class = StableDiffusionControlNetInpaintPipeline
+ params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
+ batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+ image_params = frozenset([])
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ torch.manual_seed(0)
+ controlnet = ControlNetModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ in_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ cross_attention_dim=32,
+ conditioning_embedding_out_channels=(16, 32),
+ )
+ torch.manual_seed(0)
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ }
+ return components
+
+
class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionControlNetInpaintPipeline
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
@@ -376,3 +448,60 @@ def test_canny(self):
)
assert np.abs(expected_image - image).max() < 9e-2
+
+ def test_inpaint(self):
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
+
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+ )
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(33)
+
+ init_image = load_image(
+ "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+ )
+ init_image = init_image.resize((512, 512))
+
+ mask_image = load_image(
+ "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
+ )
+ mask_image = mask_image.resize((512, 512))
+
+ prompt = "a handsome man with ray-ban sunglasses"
+
+ def make_inpaint_condition(image, image_mask):
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+ assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
+ image[image_mask > 0.5] = -1.0 # set as masked pixel
+ image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return image
+
+ control_image = make_inpaint_condition(init_image, mask_image)
+
+ output = pipe(
+ prompt,
+ image=init_image,
+ mask_image=mask_image,
+ control_image=control_image,
+ guidance_scale=9.0,
+ eta=1.0,
+ generator=generator,
+ num_inference_steps=20,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/boy_ray_ban.npy"
+ )
+
+ assert np.abs(expected_image - image).max() < 9e-2
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 44de277ead07..a9337417289a 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -193,6 +193,82 @@ def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
+ pipeline_class = StableDiffusionInpaintPipeline
+ params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
+ batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+ image_params = frozenset([])
+ # TODO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ }
+ return components
+
+ def test_stable_diffusion_inpaint(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionInpaintPipeline(**components)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.4925, 0.4967, 0.4100, 0.5234, 0.5322, 0.4532, 0.5805, 0.5877, 0.4151])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip("skipped here because the area stays unchanged due to the mask")
+ def test_stable_diffusion_inpaint_lora(self):
+ ...
+
+
@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
@@ -378,6 +454,22 @@ def test_stable_diffusion_inpaint_strength_test(self):
expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943])
assert np.abs(expected_slice - image_slice).max() < 3e-3
+ def test_stable_diffusion_simple_inpaint_ddim(self):
+ pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ inputs = self.get_inputs(torch_device)
+ image = pipe(**inputs).images
+
+ image_slice = image[0, 253:256, 253:256, -1].flatten()
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array([0.5157, 0.6858, 0.6873, 0.4619, 0.6416, 0.6898, 0.3702, 0.5960, 0.6935])
+
+ assert np.abs(expected_slice - image_slice).max() < 6e-4
+
@nightly
@require_torch_gpu