Skip to content

Commit a257ed5

Browse files
committed
nits
MultiControlNet -> MultiControlNetModel - Matches existing naming a bit closer MultiControlNetModel inherit from model utils class - Don't have to re-write fp16 test Skip tests that save multi controlnet pipeline - Clearer than changing test body Don't auto-batch the number of input images to the number of controlnets. We generally like to require the user to pass the expected number of inputs. This simplifies the processing code a bit more Use existing image pre-processing code a bit more. We can rely on the existing image pre-processing code and keep the inference loop a bit simpler.
1 parent 039db1b commit a257ed5

File tree

2 files changed

+105
-110
lines changed

2 files changed

+105
-110
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

+91-80
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
import numpy as np
2121
import PIL.Image
2222
import torch
23-
from torch import device, nn
23+
from torch import nn
2424
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
2525

2626
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
2727
from ...models.controlnet import ControlNetOutput
28-
from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
28+
from ...models.modeling_utils import ModelMixin
2929
from ...schedulers import KarrasDiffusionSchedulers
3030
from ...utils import (
3131
PIL_INTERPOLATION,
@@ -89,7 +89,7 @@
8989
"""
9090

9191

92-
class MultiControlNet(nn.Module):
92+
class MultiControlNetModel(ModelMixin):
9393
r"""
9494
Multiple `ControlNetModel` wrapper class for Multi-ControlNet
9595
@@ -102,25 +102,10 @@ class MultiControlNet(nn.Module):
102102
`ControlNetModel` as a list.
103103
"""
104104

105-
def __init__(self, controlnets: List[ControlNetModel]):
105+
def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
106106
super().__init__()
107107
self.nets = nn.ModuleList(controlnets)
108108

109-
@property
110-
def device(self) -> device:
111-
"""
112-
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
113-
device).
114-
"""
115-
return get_parameter_device(self)
116-
117-
@property
118-
def dtype(self) -> torch.dtype:
119-
"""
120-
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
121-
"""
122-
return get_parameter_dtype(self)
123-
124109
def forward(
125110
self,
126111
sample: torch.FloatTensor,
@@ -180,8 +165,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
180165
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
181166
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
182167
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
183-
Provides additional conditioning to the unet during the denoising process. You can set multiple
184-
`ControlNetModel` as a list.
168+
Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
169+
as a list, the outputs from each ControlNet are added together to create one combined additional
170+
conditioning.
185171
scheduler ([`SchedulerMixin`]):
186172
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
187173
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -199,7 +185,7 @@ def __init__(
199185
text_encoder: CLIPTextModel,
200186
tokenizer: CLIPTokenizer,
201187
unet: UNet2DConditionModel,
202-
controlnet: Union[ControlNetModel, List[ControlNetModel]],
188+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
203189
scheduler: KarrasDiffusionSchedulers,
204190
safety_checker: StableDiffusionSafetyChecker,
205191
feature_extractor: CLIPFeatureExtractor,
@@ -224,9 +210,7 @@ def __init__(
224210
)
225211

226212
if isinstance(controlnet, (list, tuple)):
227-
controlnet = MultiControlNet(controlnet)
228-
else:
229-
controlnet = controlnet
213+
controlnet = MultiControlNetModel(controlnet)
230214

231215
self.register_modules(
232216
vae=vae,
@@ -507,12 +491,14 @@ def prepare_extra_step_kwargs(self, generator, eta):
507491
def check_inputs(
508492
self,
509493
prompt,
494+
image,
510495
height,
511496
width,
512497
callback_steps,
513498
negative_prompt=None,
514499
prompt_embeds=None,
515500
negative_prompt_embeds=None,
501+
controlnet_conditioning_scale=1.0,
516502
):
517503
if height % 8 != 0 or width % 8 != 0:
518504
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -551,6 +537,40 @@ def check_inputs(
551537
f" {negative_prompt_embeds.shape}."
552538
)
553539

540+
# Check `image`
541+
542+
if isinstance(self.controlnet, ControlNetModel):
543+
self.check_image(image, prompt, prompt_embeds)
544+
elif isinstance(self.controlnet, MultiControlNetModel):
545+
if not isinstance(image, list):
546+
raise TypeError("For multiple controlnets: `image` must be type `list`")
547+
548+
if len(image) != len(self.controlnet.nets):
549+
raise ValueError(
550+
"For multiple controlnets: `image` must have the same length as the number of controlnets."
551+
)
552+
553+
for image_ in image:
554+
self.check_image(image_, prompt, prompt_embeds)
555+
else:
556+
assert False
557+
558+
# Check `controlnet_conditioning_scale`
559+
560+
if isinstance(self.controlnet, ControlNetModel):
561+
if not isinstance(controlnet_conditioning_scale, float):
562+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
563+
elif isinstance(self.controlnet, MultiControlNetModel):
564+
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
565+
self.controlnet.nets
566+
):
567+
raise ValueError(
568+
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
569+
" the same length as the number of controlnets"
570+
)
571+
else:
572+
assert False
573+
554574
def check_image(self, image, prompt, prompt_embeds):
555575
image_is_pil = isinstance(image, PIL.Image.Image)
556576
image_is_tensor = isinstance(image, torch.Tensor)
@@ -637,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
637657
return latents
638658

639659
def _default_height_width(self, height, width, image):
640-
if isinstance(image, list):
660+
# NOTE: It is possible that a list of images have different
661+
# dimensions for each image, so just checking the first image
662+
# is not _exactly_ correct, but it is simple.
663+
while isinstance(image, list):
641664
image = image[0]
642665

643666
if height is None:
@@ -658,41 +681,6 @@ def _default_height_width(self, height, width, image):
658681

659682
return height, width
660683

661-
def _prepare_images(self, image):
662-
if isinstance(self.controlnet, ControlNetModel):
663-
return [image] # convert to array for internal use
664-
else: # Multi-Controlnet
665-
if not isinstance(image, list):
666-
raise ValueError("The `image` argument needs to be specified in a `list`.")
667-
668-
num_controlnets = len(self.controlnet.nets)
669-
if len(image) % num_controlnets != 0:
670-
raise ValueError(
671-
"The length of the `image` argument list needs to be a multiple of the number of Multi-ControlNet."
672-
)
673-
674-
image_per_control = len(image) // num_controlnets
675-
676-
# let's split images over controlnets
677-
return [image[i : i + image_per_control] for i in range(0, len(image), image_per_control)]
678-
679-
def _prepare_controlnet_conditioning_scale(self, controlnet_conditioning_scale):
680-
if isinstance(self.controlnet, ControlNetModel):
681-
if not isinstance(controlnet_conditioning_scale, float):
682-
raise ValueError("The `controlnet_conditioning_scale` argument needs to be specified as a `float`.")
683-
return controlnet_conditioning_scale
684-
else: # Multi-Controlnet
685-
num_controlnets = len(self.controlnet.nets)
686-
if isinstance(controlnet_conditioning_scale, list):
687-
if len(controlnet_conditioning_scale) != num_controlnets:
688-
raise ValueError(
689-
"The length of the `controlnet_conditioning_scale` list does not match the number of Multi-ControlNet. "
690-
"If specified in `list`, it needs to have the same length as the number of Multi-ControlNet."
691-
)
692-
else:
693-
controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnets
694-
return controlnet_conditioning_scale
695-
696684
# override DiffusionPipeline
697685
def save_pretrained(
698686
self,
@@ -736,12 +724,14 @@ def __call__(
736724
prompt (`str` or `List[str]`, *optional*):
737725
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
738726
instead.
739-
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
727+
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
728+
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
740729
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
741730
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
742-
also be accepted as an image. The control image is automatically resized to fit the output image. If
743-
multiple ControlNets are specified in init, you need to set the corresponding images in the form of a
744-
list of `List[torch.FloatTensor]` or `List[PIL.Image.Image]`.
731+
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
732+
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
733+
specified in init, images must be passed as a list such that each element of the list can be correctly
734+
batched for input to a single controlnet.
745735
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
746736
The height in pixels of the generated image.
747737
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -807,21 +797,21 @@ def __call__(
807797
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
808798
(nsfw) content, according to the `safety_checker`.
809799
"""
810-
811-
# prepare `images` and `controlnet_conditioning_scale` for both a ControlNet and Multi-Controlnet
812-
# `images` here is a list where each element is a conditioning image for each ControlNet.
813-
images = self._prepare_images(image)
814-
controlnet_conditioning_scale = self._prepare_controlnet_conditioning_scale(controlnet_conditioning_scale)
815-
816800
# 0. Default height and width to unet
817-
height, width = self._default_height_width(height, width, images[0])
801+
height, width = self._default_height_width(height, width, image)
818802

819803
# 1. Check inputs. Raise error if not correct
820804
self.check_inputs(
821-
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
805+
prompt,
806+
image,
807+
height,
808+
width,
809+
callback_steps,
810+
negative_prompt,
811+
prompt_embeds,
812+
negative_prompt_embeds,
813+
controlnet_conditioning_scale,
822814
)
823-
for image in images:
824-
self.check_image(image, prompt, prompt_embeds)
825815

826816
# 2. Define call parameters
827817
if prompt is not None and isinstance(prompt, str):
@@ -837,6 +827,9 @@ def __call__(
837827
# corresponds to doing no classifier free guidance.
838828
do_classifier_free_guidance = guidance_scale > 1.0
839829

830+
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
831+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
832+
840833
# 3. Encode input prompt
841834
prompt_embeds = self._encode_prompt(
842835
prompt,
@@ -849,8 +842,8 @@ def __call__(
849842
)
850843

851844
# 4. Prepare image
852-
images = [
853-
self.prepare_image(
845+
if isinstance(self.controlnet, ControlNetModel):
846+
image = self.prepare_image(
854847
image=image,
855848
width=width,
856849
height=height,
@@ -860,8 +853,26 @@ def __call__(
860853
dtype=self.controlnet.dtype,
861854
do_classifier_free_guidance=do_classifier_free_guidance,
862855
)
863-
for image in images
864-
]
856+
elif isinstance(self.controlnet, MultiControlNetModel):
857+
images = []
858+
859+
for image_ in image:
860+
image_ = self.prepare_image(
861+
image=image_,
862+
width=width,
863+
height=height,
864+
batch_size=batch_size * num_images_per_prompt,
865+
num_images_per_prompt=num_images_per_prompt,
866+
device=device,
867+
dtype=self.controlnet.dtype,
868+
do_classifier_free_guidance=do_classifier_free_guidance,
869+
)
870+
871+
images.append(image_)
872+
873+
image = images
874+
else:
875+
assert False
865876

866877
# 5. Prepare timesteps
867878
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -896,7 +907,7 @@ def __call__(
896907
latent_model_input,
897908
t,
898909
encoder_hidden_states=prompt_embeds,
899-
controlnet_cond=images[0] if len(images) == 1 else images,
910+
controlnet_cond=image,
900911
conditioning_scale=controlnet_conditioning_scale,
901912
return_dict=False,
902913
)

tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py

+14-30
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
StableDiffusionControlNetPipeline,
2929
UNet2DConditionModel,
3030
)
31+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
3132
from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device
3233
from diffusers.utils.import_utils import is_xformers_available
3334
from diffusers.utils.testing_utils import require_torch_gpu
@@ -211,9 +212,11 @@ def get_dummy_components(self):
211212
text_encoder = CLIPTextModel(text_encoder_config)
212213
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
213214

215+
controlnet = MultiControlNetModel([controlnet1, controlnet2])
216+
214217
components = {
215218
"unet": unet,
216-
"controlnet": [controlnet1, controlnet2],
219+
"controlnet": controlnet,
217220
"scheduler": scheduler,
218221
"vae": vae,
219222
"text_encoder": text_encoder,
@@ -268,30 +271,7 @@ def test_xformers_attention_forwardGenerator_pass(self):
268271
def test_inference_batch_single_identical(self):
269272
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
270273

271-
# override PipelineTesterMixin
272-
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
273-
def test_float16_inference(self):
274-
components = self.get_dummy_components()
275-
pipe = self.pipeline_class(**components)
276-
pipe.to(torch_device)
277-
pipe.set_progress_bar_config(disable=None)
278-
279-
for name, module in components.items():
280-
if isinstance(module, list):
281-
components[name] = [m.half() for m in module if hasattr(m, "half")]
282-
elif hasattr(module, "half"):
283-
components[name] = module.half()
284-
pipe_fp16 = self.pipeline_class(**components)
285-
pipe_fp16.to(torch_device)
286-
pipe_fp16.set_progress_bar_config(disable=None)
287-
288-
output = pipe(**self.get_dummy_inputs(torch_device))[0]
289-
output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]
290-
291-
max_diff = np.abs(output - output_fp16).max()
292-
self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
293-
294-
def check_save_pretrained_raise_not_implemented_exception(self):
274+
def test_save_pretrained_raise_not_implemented_exception(self):
295275
components = self.get_dummy_components()
296276
pipe = self.pipeline_class(**components)
297277
pipe.to(torch_device)
@@ -304,17 +284,19 @@ def check_save_pretrained_raise_not_implemented_exception(self):
304284
pass
305285

306286
# override PipelineTesterMixin
307-
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
287+
@unittest.skip("save pretrained not implemented")
308288
def test_save_load_float16(self):
309-
self.check_save_pretrained_raise_not_implemented_exception()
289+
...
310290

311291
# override PipelineTesterMixin
292+
@unittest.skip("save pretrained not implemented")
312293
def test_save_load_local(self):
313-
self.check_save_pretrained_raise_not_implemented_exception()
294+
...
314295

315296
# override PipelineTesterMixin
297+
@unittest.skip("save pretrained not implemented")
316298
def test_save_load_optional_components(self):
317-
self.check_save_pretrained_raise_not_implemented_exception()
299+
...
318300

319301

320302
@slow
@@ -605,6 +587,8 @@ def test_pose_and_canny(self):
605587

606588
assert image.shape == (768, 512, 3)
607589

608-
expected_image = load_numpy("https://huggingface.co/takuma104/controlnet_dev/resolve/main/pose_canny_out.npy")
590+
expected_image = load_numpy(
591+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose_canny_out.npy"
592+
)
609593

610594
assert np.abs(expected_image - image).max() < 5e-2

0 commit comments

Comments
 (0)