Commit d114d80

[Stable Diffusion Inpainting] Allow standard text-to-image checkpoints to be usable for SD inpainting (#3533)

* Add default to inpaint
* Make sure controlnet also works with normal sd for inpaint
* Add tests
* improve
* Correct encode images function
* Correct inpaint controlnet
* Improve text2img inpaint
* make style
* up
* up
* up
* up
* fix more
1 parent e5215de commit d114d80
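The headline change is that the inpainting pipelines no longer require a checkpoint fine-tuned for inpainting. As a minimal sketch of what this enables (not taken verbatim from the commit; it assumes the plain `StableDiffusionInpaintPipeline` gains the same behaviour as the ControlNet variant patched below, and reuses the image URLs from the new docstring):

```py
# Sketch only: load a standard text-to-image checkpoint into the inpainting
# pipeline, which is what this commit is meant to allow (previously an
# inpaint-specific checkpoint such as runwayml/stable-diffusion-inpainting was required).
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

init_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
).resize((512, 512))
mask_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
).resize((512, 512))

image = pipe(
    "a beautiful man",
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(1),
).images[0]
```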

File tree

6 files changed: +456 -113 lines changed


src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 129 additions & 55 deletions
@@ -50,49 +50,59 @@
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> # !pip install opencv-python transformers accelerate
-        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
+        >>> # !pip install transformers accelerate
+        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
         >>> from diffusers.utils import load_image
         >>> import numpy as np
         >>> import torch

-        >>> import cv2
-        >>> from PIL import Image
+        >>> init_image = load_image(
+        ...     "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+        ... )
+        >>> init_image = init_image.resize((512, 512))
+
+        >>> generator = torch.Generator(device="cpu").manual_seed(1)
+
+        >>> mask_image = load_image(
+        ...     "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
+        ... )
+        >>> mask_image = mask_image.resize((512, 512))
+

-        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+        >>> def make_inpaint_condition(image, image_mask):
+        ...     image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+        ...     image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

-        >>> init_image = load_image(img_url).resize((512, 512))
-        >>> mask_image = load_image(mask_url).resize((512, 512))
+        ...     assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+        ...     image[image_mask > 0.5] = -1.0  # set as masked pixel
+        ...     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+        ...     image = torch.from_numpy(image)
+        ...     return image

-        >>> image = np.array(init_image)

-        >>> # get canny image
-        >>> image = cv2.Canny(image, 100, 200)
-        >>> image = image[:, :, None]
-        >>> image = np.concatenate([image, image, image], axis=2)
-        >>> canny_image = Image.fromarray(image)
+        >>> control_image = make_inpaint_condition(init_image, mask_image)

-        >>> # load control net and stable diffusion inpainting
-        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+        >>> controlnet = ControlNetModel.from_pretrained(
+        ...     "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
+        ... )
         >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-        ...     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
+        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
         ... )

         >>> # speed up diffusion process with faster scheduler and memory optimization
-        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

         >>> pipe.enable_model_cpu_offload()

         >>> # generate image
-        >>> generator = torch.manual_seed(0)
         >>> image = pipe(
-        ...     "spiderman",
-        ...     num_inference_steps=30,
+        ...     "a beautiful man",
+        ...     num_inference_steps=20,
         ...     generator=generator,
+        ...     eta=1.0,
         ...     image=init_image,
         ...     mask_image=mask_image,
-        ...     control_image=canny_image,
+        ...     control_image=control_image,
         ... ).images[0]
         ```
     """
@@ -226,6 +236,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]

+    <Tip>
+
+    This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as
+    [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)
+    as well as default text-to-image stable diffusion checkpoints, such as
+    [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
+    Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on
+    those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
+
+    </Tip>
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -597,6 +618,16 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
@@ -812,6 +843,8 @@ def prepare_latents(
         image=None,
         timestep=None,
         is_strength_max=True,
+        return_noise=False,
+        return_image_latents=False,
     ):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -826,32 +859,28 @@ def prepare_latents(
                 "However, either the image or the noise timestep has not been provided."
             )

+        if return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            if is_strength_max:
-                # if strength is 100% then simply initialise the latents to noise
-                latents = noise
-            else:
-                # otherwise initialise latents as init image + noise
-                image = image.to(device=device, dtype=dtype)
-                if isinstance(generator, list):
-                    image_latents = [
-                        self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
-                        for i in range(batch_size)
-                    ]
-                else:
-                    image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
-
-                image_latents = self.vae.config.scaling_factor * image_latents
-
-                latents = self.scheduler.add_noise(image_latents, noise, timestep)
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             latents = latents.to(device)

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

-        return latents
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs

     def _default_height_width(self, height, width, image):
         # NOTE: It is possible that a list of images have different
@@ -891,17 +920,7 @@ def prepare_mask_latents(
         mask = mask.to(device=device, dtype=dtype)

         masked_image = masked_image.to(device=device, dtype=dtype)
-
-        # encode the mask image into latents space so we can concatenate it to the latents
-        if isinstance(generator, list):
-            masked_image_latents = [
-                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
-                for i in range(batch_size)
-            ]
-            masked_image_latents = torch.cat(masked_image_latents, dim=0)
-        else:
-            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
@@ -930,6 +949,21 @@ def prepare_mask_latents(
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
     # override DiffusionPipeline
     def save_pretrained(
         self,
@@ -954,6 +988,7 @@ def __call__(
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -990,6 +1025,13 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image.
+            strength (`float`, *optional*, defaults to 1.):
+                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+                `strength`. The number of denoising steps depends on the amount of noise initially added. When
+                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+                portion of the reference `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
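To make the new `strength` argument concrete, here is a small worked illustration (not part of the diff) of how the `get_timesteps` helper added earlier in this file trims the schedule:

```py
# Illustration only: mirrors the arithmetic in the get_timesteps() helper added above.
num_inference_steps, strength = 50, 0.5
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 25
t_start = max(num_inference_steps - init_timestep, 0)                          # 25
# timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
# -> denoising starts halfway through the schedule and runs 25 steps;
#    strength=1.0 keeps all 50 steps and starts from pure noise.
```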
@@ -1145,13 +1187,25 @@ def __call__(
             assert False

         # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+        mask, masked_image, init_image = prepare_mask_and_masked_image(
+            image, mask_image, height, width, return_image=True
+        )
+
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps=num_inference_steps, strength=strength, device=device
+        )
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.0

         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
-        latents = self.prepare_latents(
+        num_channels_unet = self.unet.config.in_channels
+        return_image_latents = num_channels_unet == 4
+        latents_outputs = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -1160,10 +1214,19 @@ def __call__(
             device,
             generator,
             latents,
+            image=init_image,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            return_noise=True,
+            return_image_latents=return_image_latents,
         )

+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+
         # 7. Prepare mask latent variables
-        mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
         mask, masked_image_latents = self.prepare_mask_latents(
             mask,
             masked_image,
@@ -1213,7 +1276,9 @@ def __call__(
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                 # predict the noise residual
-                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+                if num_channels_unet == 9:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
@@ -1232,6 +1297,15 @@ def __call__(
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if num_channels_unet == 4:
+                    init_latents_proper = image_latents[:1]
+                    init_mask = mask[:1]
+
+                    if i < len(timesteps) - 1:
+                        init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
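The hunk above is the key trick that lets a plain 4-channel text-to-image UNet inpaint: after every scheduler step, the known (unmasked) region is overwritten with the original image latents, re-noised to the current timestep, so only the masked region is actually generated. A standalone sketch of the idea (illustrative, not the exact pipeline code):

```py
# Sketch of the per-step blending used for 4-channel (text-to-image) UNets above.
import torch

def blend_known_region(latents, image_latents, mask, scheduler, noise, t, is_last_step):
    init_latents_proper = image_latents[:1]
    init_mask = mask[:1]
    if not is_last_step:
        # re-noise the clean image latents so they match the current noise level
        init_latents_proper = scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
    # outside the mask keep the (re-noised) original; inside the mask keep the model's result
    return (1 - init_mask) * init_latents_proper + init_mask * latents
```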

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

Lines changed: 16 additions & 11 deletions
@@ -328,17 +328,7 @@ def prepare_mask_latents(
         mask = mask.to(device=device, dtype=dtype)

         masked_image = masked_image.to(device=device, dtype=dtype)
-
-        # encode the mask image into latents space so we can concatenate it to the latents
-        if isinstance(generator, list):
-            masked_image_latents = [
-                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
-                for i in range(batch_size)
-            ]
-            masked_image_latents = torch.cat(masked_image_latents, dim=0)
-        else:
-            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
@@ -367,6 +357,21 @@ def prepare_mask_latents(
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
     def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
         dtype = next(self.image_encoder.parameters()).dtype
