[Stable Diffusion Inpainting] Allow standard text-to-img checkpoints to be useable for SD inpainting #3533

Merged: 13 commits, May 26, 2023
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py (184 changes: 129 additions, 55 deletions)
@@ -50,49 +50,59 @@
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # !pip install opencv-python transformers accelerate
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
>>> # !pip install transformers accelerate
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch

>>> import cv2
>>> from PIL import Image
>>> init_image = load_image(
... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
... )
>>> init_image = init_image.resize((512, 512))

>>> generator = torch.Generator(device="cpu").manual_seed(1)

>>> mask_image = load_image(
... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
... )
>>> mask_image = mask_image.resize((512, 512))


>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
>>> def make_inpaint_condition(image, image_mask):
... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

>>> init_image = load_image(img_url).resize((512, 512))
>>> mask_image = load_image(mask_url).resize((512, 512))
... assert image.shape[:2] == image_mask.shape[:2], "image and image_mask must have the same image size"
... image[image_mask > 0.5] = -1.0 # set as masked pixel
... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
... image = torch.from_numpy(image)
... return image

>>> image = np.array(init_image)

>>> # get canny image
>>> image = cv2.Canny(image, 100, 200)
>>> image = image[:, :, None]
>>> image = np.concatenate([image, image, image], axis=2)
>>> canny_image = Image.fromarray(image)
>>> control_image = make_inpaint_condition(init_image, mask_image)

>>> # load control net and stable diffusion inpainting
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> controlnet = ControlNetModel.from_pretrained(
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )

>>> # speed up diffusion process with faster scheduler and memory optimization
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

>>> pipe.enable_model_cpu_offload()

>>> # generate image
>>> generator = torch.manual_seed(0)
>>> image = pipe(
... "spiderman",
... num_inference_steps=30,
... "a beautiful man",
... num_inference_steps=20,
... generator=generator,
... eta=1.0,
... image=init_image,
... mask_image=mask_image,
... control_image=canny_image,
... control_image=control_image,
... ).images[0]
```
"""
@@ -226,6 +236,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
In addition, the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]

<Tip>

This pipeline can be used with both checkpoints that have been specifically fine-tuned for inpainting, such as
[runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting),
and default text-to-image Stable Diffusion checkpoints, such as
[runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
them, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).

</Tip>
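
As a quick illustration of the tip above, here is a minimal loading sketch of the two modes. The repo ids come from the tip, `torch_dtype` mirrors the docstring example, and the 9- vs 4-channel UNet distinction comes from the `num_channels_unet` checks later in this diff:

```py
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
)

# Option A: a checkpoint fine-tuned for inpainting (9-channel UNet input).
pipe_inpaint = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
)

# Option B: a plain text-to-image checkpoint (4-channel UNet input), enabled by this PR.
pipe_t2i = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
```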

Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -597,6 +618,16 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

return timesteps, num_inference_steps - t_start
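
For intuition, the arithmetic in `get_timesteps` works out as follows; a worked sketch assuming a first-order scheduler (i.e. `scheduler.order == 1`):

```py
# strength=0.5 with a 50-step schedule:
num_inference_steps, strength = 50, 0.5
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 25
t_start = max(num_inference_steps - init_timestep, 0)                          # 25
# timesteps[t_start * order:] keeps the last 25 entries, so only 25 denoising
# steps actually run; strength=1.0 gives t_start=0 and the full schedule is used.
```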

def check_inputs(
self,
prompt,
@@ -812,6 +843,8 @@ def prepare_latents(
image=None,
timestep=None,
is_strength_max=True,
return_noise=False,
return_image_latents=False,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
@@ -826,32 +859,28 @@
"However, either the image or the noise timestep has not been provided."
)

if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if is_strength_max:
# if strength is 100% then simply initialise the latents to noise
latents = noise
else:
# otherwise initialise latents as init image + noise
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list):
image_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(batch_size)
]
else:
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)

image_latents = self.vae.config.scaling_factor * image_latents

latents = self.scheduler.add_noise(image_latents, noise, timestep)
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
else:
latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

return latents
outputs = (latents,)

if return_noise:
outputs += (noise,)

if return_image_latents:
outputs += (image_latents,)

return outputs
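
The `noise if is_strength_max else self.scheduler.add_noise(...)` branch above is the heart of partial-strength inpainting: at `strength == 1.0` the latents start as pure noise, otherwise they start from the encoded image noised to the chosen timestep. A runnable sketch of that branch, with assumed toy shapes and an assumed `DDPMScheduler` (any scheduler exposing `add_noise` behaves analogously):

```py
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

image_latents = torch.zeros(1, 4, 64, 64)  # stand-in for the VAE-encoded init image
noise = torch.randn(1, 4, 64, 64)
timestep = scheduler.timesteps[:1]         # latent_timestep when strength < 1.0

is_strength_max = False
latents = noise if is_strength_max else scheduler.add_noise(image_latents, noise, timestep)
```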

def _default_height_width(self, height, width, image):
# NOTE: It is possible that a list of images have different
@@ -891,17 +920,7 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)

masked_image = masked_image.to(device=device, dtype=dtype)

# encode the mask image into latents space so we can concatenate it to the latents
if isinstance(generator, list):
masked_image_latents = [
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(batch_size)
]
masked_image_latents = torch.cat(masked_image_latents, dim=0)
else:
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -930,6 +949,21 @@
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)

image_latents = self.vae.config.scaling_factor * image_latents

return image_latents
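
The list branch in `_encode_vae_image` lets each batch element draw from its own seeded generator, so individual images stay reproducible regardless of batch size. The same pattern sketched with plain `torch.randn` standing in for the VAE's `latent_dist.sample`:

```py
import torch

generators = [torch.Generator("cpu").manual_seed(seed) for seed in (0, 1, 2)]
samples = [torch.randn(1, 4, 64, 64, generator=g) for g in generators]
batch = torch.cat(samples, dim=0)  # mirrors torch.cat(image_latents, dim=0) above
# element i is identical whether sampled alone or as part of a larger batch
```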

# override DiffusionPipeline
def save_pretrained(
self,
@@ -954,6 +988,7 @@ def __call__(
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -990,6 +1025,13 @@ def __call__(
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
strength (`float`, *optional*, defaults to 1.0):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
`strength`. The number of denoising steps depends on the amount of noise initially added. When
`strength` is 1, added noise will be maximum and the denoising process will run for the full number of
iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
portion of the reference `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -1145,13 +1187,25 @@ def __call__(
assert False

# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
mask, masked_image, init_image = prepare_mask_and_masked_image(
image, mask_image, height, width, return_image=True
)

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# check whether strength is set to 1.0; if so, initialise the latents with pure noise
is_strength_max = strength == 1.0

# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4
latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@@ -1160,10 +1214,19 @@
device,
generator,
latents,
image=init_image,
timestep=latent_timestep,
is_strength_max=is_strength_max,
return_noise=True,
return_image_latents=return_image_latents,
)

if return_image_latents:
latents, noise, image_latents = latents_outputs
else:
latents, noise = latents_outputs

# 7. Prepare mask latent variables
mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
@@ -1213,7 +1276,9 @@ def __call__(
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

# predict the noise residual
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
if num_channels_unet == 9:
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

noise_pred = self.unet(
latent_model_input,
t,
Expand All @@ -1232,6 +1297,15 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

if num_channels_unet == 4:
init_latents_proper = image_latents[:1]
init_mask = mask[:1]

if i < len(timesteps) - 1:
init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))

latents = (1 - init_mask) * init_latents_proper + init_mask * latents

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
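
The `num_channels_unet == 4` branch added to the denoising loop above is what makes plain text-to-image UNets usable for inpainting: after every scheduler step, everything outside the mask is overwritten with a copy of the original image latents re-noised to the current timestep, so only the masked region is actually generated. A standalone sketch of that blending rule, using toy tensors and an assumed `DDPMScheduler` in place of the real UNet and pipeline state:

```py
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(20)

image_latents = torch.zeros(1, 4, 64, 64)  # stand-in for the encoded init image
noise = torch.randn(1, 4, 64, 64)
mask = torch.zeros(1, 1, 64, 64)
mask[..., 16:48, 16:48] = 1.0              # 1 = region to repaint

latents = noise.clone()
for i, t in enumerate(scheduler.timesteps):
    # ... UNet prediction and scheduler.step() would update `latents` here ...
    if i < len(scheduler.timesteps) - 1:
        # re-noise the original latents to the current noise level
        init_latents_proper = scheduler.add_noise(image_latents, noise, torch.tensor([t]))
    else:
        init_latents_proper = image_latents
    # lock the unmasked region to the (re-noised) original content
    latents = (1 - mask) * init_latents_proper + mask * latents
```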
@@ -328,17 +328,7 @@ def prepare_mask_latents(
mask = mask.to(device=device, dtype=dtype)

masked_image = masked_image.to(device=device, dtype=dtype)

# encode the mask image into latents space so we can concatenate it to the latents
if isinstance(generator, list):
masked_image_latents = [
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(batch_size)
]
masked_image_latents = torch.cat(masked_image_latents, dim=0)
else:
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -367,6 +357,21 @@
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)

image_latents = self.vae.config.scaling_factor * image_latents

return image_latents

def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
dtype = next(self.image_encoder.parameters()).dtype
