 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> # !pip install opencv-python transformers accelerate
-        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
+        >>> # !pip install transformers accelerate
+        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
         >>> from diffusers.utils import load_image
         >>> import numpy as np
         >>> import torch

-        >>> import cv2
-        >>> from PIL import Image
+        >>> init_image = load_image(
+        ...     "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+        ... )
+        >>> init_image = init_image.resize((512, 512))
+
+        >>> generator = torch.Generator(device="cpu").manual_seed(1)
+
+        >>> mask_image = load_image(
+        ...     "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
+        ... )
+        >>> mask_image = mask_image.resize((512, 512))
+

-        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+        >>> def make_inpaint_condition(image, image_mask):
+        ...     image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+        ...     image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

-        >>> init_image = load_image(img_url).resize((512, 512))
-        >>> mask_image = load_image(mask_url).resize((512, 512))
+        ...     assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+        ...     image[image_mask > 0.5] = -1.0  # set as masked pixel
+        ...     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+        ...     image = torch.from_numpy(image)
+        ...     return image

-        >>> image = np.array(init_image)

-        >>> # get canny image
-        >>> image = cv2.Canny(image, 100, 200)
-        >>> image = image[:, :, None]
-        >>> image = np.concatenate([image, image, image], axis=2)
-        >>> canny_image = Image.fromarray(image)
+        >>> control_image = make_inpaint_condition(init_image, mask_image)

-        >>> # load control net and stable diffusion inpainting
-        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+        >>> controlnet = ControlNetModel.from_pretrained(
+        ...     "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
+        ... )
         >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-        ...     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
+        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
         ... )

         >>> # speed up diffusion process with faster scheduler and memory optimization
-        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
         >>> pipe.enable_model_cpu_offload()

         >>> # generate image
-        >>> generator = torch.manual_seed(0)
         >>> image = pipe(
-        ...     "spiderman",
-        ...     num_inference_steps=30,
+        ...     "a beautiful man",
+        ...     num_inference_steps=20,
         ...     generator=generator,
+        ...     eta=1.0,
         ...     image=init_image,
         ...     mask_image=mask_image,
-        ...     control_image=canny_image,
+        ...     control_image=control_image,
         ... ).images[0]
         ```
 """
@@ -226,6 +236,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]

+    <Tip>
+
+    This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting, such as
+    [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), as well as
+    with default text-to-image Stable Diffusion checkpoints, such as
+    [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Default text-to-image
+    Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on them, such as
+    [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
+
+    </Tip>
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
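As the tip notes, inpainting-specialized checkpoints remain supported. A hedged sketch of that variant, pairing the 9-channel `runwayml/stable-diffusion-inpainting` base with the Canny ControlNet from the old example (loading code only; conditioning would follow the previous Canny flow):

```py
import torch
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # inpainting-specific checkpoint (9-channel UNet)
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
```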
@@ -597,6 +618,16 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
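`get_timesteps` trims the schedule so that only the tail of the denoising trajectory runs when `strength < 1`. The arithmetic, with assumed example values:

```py
# num_inference_steps=50, strength=0.5, first-order scheduler (scheduler.order == 1)
num_inference_steps, strength, order = 50, 0.5, 1
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 25
t_start = max(num_inference_steps - init_timestep, 0)                          # 25
# timesteps = scheduler.timesteps[t_start * order:]  -> only the last 25 steps run
```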
@@ -812,6 +843,8 @@ def prepare_latents(
         image=None,
         timestep=None,
         is_strength_max=True,
+        return_noise=False,
+        return_image_latents=False,
     ):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -826,32 +859,28 @@ def prepare_latents(
                 "However, either the image or the noise timestep has not been provided."
             )

+        if return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            if is_strength_max:
-                # if strength is 100% then simply initialise the latents to noise
-                latents = noise
-            else:
-                # otherwise initialise latents as init image + noise
-                image = image.to(device=device, dtype=dtype)
-                if isinstance(generator, list):
-                    image_latents = [
-                        self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
-                        for i in range(batch_size)
-                    ]
-                else:
-                    image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
-
-                image_latents = self.vae.config.scaling_factor * image_latents
-
-                latents = self.scheduler.add_noise(image_latents, noise, timestep)
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             latents = latents.to(device)

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

-        return latents
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs

     def _default_height_width(self, height, width, image):
         # NOTE: It is possible that a list of images have different
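The refactor collapses the two initialization paths into one expression. A hypothetical helper mirroring that line, to make the rule explicit:

```py
def init_latents_sketch(noise, image_latents, timestep, scheduler, is_strength_max):
    # strength == 1.0: start from pure noise (the init image is ignored);
    # strength <  1.0: re-noise the VAE-encoded init image up to the first trimmed timestep.
    return noise if is_strength_max else scheduler.add_noise(image_latents, noise, timestep)
```

The tuple return lets callers also receive `noise` (needed later to re-noise image latents) and `image_latents` (needed when the UNet has only 4 input channels).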
@@ -891,17 +920,7 @@ def prepare_mask_latents(
         mask = mask.to(device=device, dtype=dtype)

         masked_image = masked_image.to(device=device, dtype=dtype)
-
-        # encode the mask image into latents space so we can concatenate it to the latents
-        if isinstance(generator, list):
-            masked_image_latents = [
-                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
-                for i in range(batch_size)
-            ]
-            masked_image_latents = torch.cat(masked_image_latents, dim=0)
-        else:
-            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-            masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
@@ -930,6 +949,21 @@ def prepare_mask_latents(
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
     # override DiffusionPipeline
     def save_pretrained(
         self,
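The list branch in `_encode_vae_image` exists so each batch element can be sampled with its own `torch.Generator`, keeping results reproducible per sample. A hedged usage sketch with assumed shapes:

```py
import torch

# One generator per batch element; shapes are illustrative.
generators = [torch.Generator(device="cpu").manual_seed(s) for s in (0, 1)]
images = torch.randn(2, 3, 512, 512)  # stands in for preprocessed images in [-1, 1]
# latents = pipe._encode_vae_image(images.to(pipe.device), generator=generators)
# latents.shape -> (2, 4, 64, 64), already scaled by vae.config.scaling_factor
```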
@@ -954,6 +988,7 @@ def __call__(
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -990,6 +1025,13 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image.
+            strength (`float`, *optional*, defaults to 1.0):
+                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+                between 0 and 1. `image` is used as a starting point, and more noise is added the higher the
+                `strength`. The number of denoising steps depends on the amount of noise initially added. When
+                `strength` is 1, the added noise is maximal and the denoising process runs for the full number of
+                iterations specified in `num_inference_steps`. A value of 1 therefore essentially ignores the masked
+                portion of the reference `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
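A hedged usage sketch of the new argument, reusing the objects from the docstring example; lowering `strength` preserves more of the original masked content because denoising starts from a partially noised `image` rather than pure noise:

```py
# Assumes pipe, init_image, mask_image, control_image, generator from the example above.
image = pipe(
    "a beautiful man",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    strength=0.6,            # start 60% of the way into the noise schedule
    num_inference_steps=20,  # only int(20 * 0.6) = 12 denoising steps actually run
    generator=generator,
).images[0]
```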
@@ -1145,13 +1187,25 @@ def __call__(
             assert False

         # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+        mask, masked_image, init_image = prepare_mask_and_masked_image(
+            image, mask_image, height, width, return_image=True
+        )
+
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps=num_inference_steps, strength=strength, device=device
+        )
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # check if strength is set to 1, in which case the latents are initialised with pure noise
+        is_strength_max = strength == 1.0

         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
-        latents = self.prepare_latents(
+        num_channels_unet = self.unet.config.in_channels
+        return_image_latents = num_channels_unet == 4
+        latents_outputs = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -1160,10 +1214,19 @@ def __call__(
             device,
             generator,
             latents,
+            image=init_image,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            return_noise=True,
+            return_image_latents=return_image_latents,
         )

+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+
         # 7. Prepare mask latent variables
-        mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
         mask, masked_image_latents = self.prepare_mask_latents(
             mask,
             masked_image,
@@ -1213,7 +1276,9 @@ def __call__(
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                 # predict the noise residual
-                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+                if num_channels_unet == 9:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
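The concatenation is now gated on the UNet's input width: only inpainting-specialized (9-channel) UNets receive the mask and masked-image latents as extra channels. A hypothetical shape check of that layout:

```py
import torch

# 4 latent + 1 mask + 4 masked-image latent channels = 9-channel UNet input
# (batch 2 for classifier-free guidance; 64x64 latents for a 512x512 image).
latent_model_input = torch.cat(
    [torch.randn(2, 4, 64, 64), torch.randn(2, 1, 64, 64), torch.randn(2, 4, 64, 64)], dim=1
)
assert latent_model_input.shape == (2, 9, 64, 64)
```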
@@ -1232,6 +1297,15 @@ def __call__(
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if num_channels_unet == 4:
+                    init_latents_proper = image_latents[:1]
+                    init_mask = mask[:1]
+
+                    if i < len(timesteps) - 1:
+                        init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
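For plain 4-channel UNets, the loop instead pins the unmasked region back to a re-noised copy of the original image latents after every step; that is what keeps content outside the mask intact. A standalone illustration of the blend, with shapes and mask assumed:

```py
import torch

init_mask = torch.zeros(1, 1, 64, 64)
init_mask[..., 16:48, 16:48] = 1.0               # region to inpaint
init_latents_proper = torch.randn(1, 4, 64, 64)  # re-noised original latents
latents = torch.randn(1, 4, 64, 64)              # denoiser output at this step

blended = (1 - init_mask) * init_latents_proper + init_mask * latents
# Outside the mask the original latents survive unchanged.
assert torch.equal(blended[..., 0, 0], init_latents_proper[..., 0, 0])
```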