
Commit d1acef4

cleanups

- don't expose ControlNetCondition
- move scaling to ControlNetModel

1 parent 91affab · commit d1acef4

File tree

6 files changed (+50, -50 lines)

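In short: `ControlNetCondition` is no longer part of the public API, conditioning images go straight into the pipeline call, and residual scaling now happens inside `ControlNetModel.forward` via the new `conditioning_scale` argument. A minimal usage sketch of the API after this commit; the checkpoint names and the random stand-in image are illustrative assumptions, not part of the diff:

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    # Hypothetical checkpoints, for illustration only.
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet
    )

    cond_image = torch.randn(1, 3, 512, 512)  # stand-in for a real conditioning image
    result = pipe(
        "a photo of a cat",
        image=cond_image,                   # raw image, no ControlNetCondition wrapper
        controlnet_conditioning_scale=0.7,  # forwarded to ControlNetModel as conditioning_scale
    )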

src/diffusers/__init__.py (-1)

@@ -109,7 +109,6 @@
 from .pipelines import (
     AltDiffusionImg2ImgPipeline,
     AltDiffusionPipeline,
-    ControlNetCondition,
     CycleDiffusionPipeline,
     LDMTextToImagePipeline,
     PaintByExamplePipeline,

src/diffusers/models/controlnet.py (+5)

@@ -389,6 +389,7 @@ def forward(
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         controlnet_cond: torch.FloatTensor,
+        conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -492,6 +493,10 @@ def forward(

         mid_block_res_sample = self.controlnet_mid_block(sample)

+        # 6. scaling
+        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+        mid_block_res_sample *= conditioning_scale
+
         if not return_dict:
             return (down_block_res_samples, mid_block_res_sample)
src/diffusers/pipelines/__init__.py (-1)

@@ -46,7 +46,6 @@
 from .paint_by_example import PaintByExamplePipeline
 from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
 from .stable_diffusion import (
-    ControlNetCondition,
     CycleDiffusionPipeline,
     StableDiffusionAttendAndExcitePipeline,
     StableDiffusionControlNetPipeline,

src/diffusers/pipelines/stable_diffusion/__init__.py (+1, -1)

@@ -45,7 +45,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
 from .pipeline_cycle_diffusion import CycleDiffusionPipeline
 from .pipeline_stable_diffusion import StableDiffusionPipeline
 from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
-from .pipeline_stable_diffusion_controlnet import ControlNetCondition, StableDiffusionControlNetPipeline
+from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline
 from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
 from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py (+34, -32)

@@ -92,10 +92,8 @@ class ControlNetCondition:
     def __init__(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
-        scale: float = 1.0,
     ):
         self.image_original = image
-        self.scale = scale

     def _default_height_width(self, height, width, image):
         if isinstance(image, list):
@@ -226,38 +224,28 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        images: List[torch.tensor],
-        cond_scales: List[int],
+        controlnet_cond: List[torch.tensor],
+        conditioning_scale: List[float],
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
-        if len(controlnet_conditions) != len(self.nets):
-            raise ValueError(
-                f"The number of specified `ControlNetCondition` does not match the number of `ControlNetModel`."
-                f"There are {len(self.nets)} ControlNetModel(s), "
-                f"but there are {len(controlnet_conditions)} `ControlNetCondition` in `controlnet_conditions`."
-            )
-
-        for i, (image, scale, controlnet) in enumerate(zip(images, scales, self.nets)):
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
             down_samples, mid_sample = controlnet(
                 sample,
                 timestep,
                 encoder_hidden_states,
                 image,
+                scale,
                 class_labels,
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
                 return_dict,
             )

-            # scaling
-            down_samples = [sample * cond.scale for sample in down_samples]
-            mid_sample *= cond.scale
-
             # merge samples
             if i == 0:
                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
@@ -700,8 +688,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: List[float] = 1.0,
-        controlnet_conditions: Optional[List[ControlNetCondition]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -780,18 +767,28 @@ def __call__(
             (nsfw) content, according to the `safety_checker`.
         """

-        # TODO: add conversion image array to ControlNetConditions
-        if controlnet_conditions is None:
+        # TODO: refactoring
+        if isinstance(self.controlnet, ControlNetModel):
+            controlnet_conditions = [ControlNetCondition(image=image)]
+        elif isinstance(self.controlnet, MultiControlNet):
+            # TODO: refactoring
             # let's split images over controlnets
-            image_per_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
-            if image_per_control > 1 and not isinstance(image, list):
-                raise ValueError(...)
-
-            if len(image) % image_per_control != 0:
-                raise ValueError(...)
-
-            images = [image[i:i+num_image_per_control] for i in range(0, len(image), image_per_control)]
-            controlnet_conditions = [ControlNetCondition(image=image, scale=scale) for image, scale in zip(images, controlnet_conditioning_scale)]
+            # image_per_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
+            # if image_per_control > 1 and not isinstance(image, list):
+            #     raise ValueError(...)
+
+            # if len(image) % image_per_control != 0:
+            #     raise ValueError(...)
+
+            # if image_per_control > 1 and not isinstance(controlnet_conditioning_scale, list):
+            #     controlnet_conditioning_scale = [controlnet_conditioning_scale] * image_per_control
+
+            # images = [image[i:i+image_per_control] for i in range(0, len(image), image_per_control)]
+
+            num_controlnets = len(self.controlnet.nets)
+            if num_controlnets > 1 and not isinstance(controlnet_conditioning_scale, list):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnets
+            controlnet_conditions = [ControlNetCondition(image=img) for img in image]

         # 0. Default height and width to unet
         height, width = controlnet_conditions[0].default_height_width(height, width)
@@ -829,14 +826,14 @@ def __call__(
         )

         # 4. Prepare image
-        for cond, controlnet in zip(controlnet_conditions, self.controlnet.nets):
+        for cond in controlnet_conditions:
             cond.prepare_image(
                 width=width,
                 height=height,
                 batch_size=batch_size * num_images_per_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
-                dtype=controlnet.dtype,
+                dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
             )

@@ -869,11 +866,16 @@ def __call__(
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
+               if len(controlnet_conditions) > 1:
+                   controlnet_cond = [c.image for c in controlnet_conditions]
+               else:
+                   controlnet_cond = controlnet_conditions[0].image
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
-                   controlnet_conditions=controlnet_conditions,
+                   controlnet_cond=controlnet_cond,
+                   conditioning_scale=controlnet_conditioning_scale,
                    return_dict=False,
                )

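For the multi-ControlNet path, `__call__` now broadcasts a bare float so that each `ControlNetModel` in `MultiControlNet` receives its own `conditioning_scale`, and `forward` zips images, scales, and nets together. A sketch of that broadcast rule; the helper name is hypothetical:

    from typing import List, Union

    def broadcast_conditioning_scale(
        scale: Union[float, List[float]], num_controlnets: int
    ) -> Union[float, List[float]]:
        # Mirrors the new __call__ logic: a single float fans out to one
        # scale per ControlNetModel; an explicit list passes through as-is.
        if num_controlnets > 1 and not isinstance(scale, list):
            return [scale] * num_controlnets
        return scale

    assert broadcast_conditioning_scale(0.5, 3) == [0.5, 0.5, 0.5]
    assert broadcast_conditioning_scale([1.0, 0.8], 2) == [1.0, 0.8]
    assert broadcast_conditioning_scale(0.5, 1) == 0.5  # single-net case stays a float
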
tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py (+10, -15)

@@ -23,7 +23,6 @@

 from diffusers import (
     AutoencoderKL,
-    ControlNetCondition,
     ControlNetModel,
     DDIMScheduler,
     StableDiffusionControlNetPipeline,
@@ -232,20 +231,16 @@ def get_dummy_inputs(self, device, seed=0):

         controlnet_embedder_scale_factor = 2

-        conditions = [
-            ControlNetCondition(
-                image=randn_tensor(
-                    (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
-                    generator=generator,
-                    device=torch.device(device),
-                )
+        images = [
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
             ),
-            ControlNetCondition(
-                image=randn_tensor(
-                    (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
-                    generator=generator,
-                    device=torch.device(device),
-                )
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
             ),
         ]
@@ -255,7 +250,7 @@ def get_dummy_inputs(self, device, seed=0):
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
             "output_type": "numpy",
-            "controlnet_conditions": conditions,
+            "image": images,
         }

         return inputs

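With the wrapper gone, the test fixture hands raw tensors to the pipeline under the `image` key, one per ControlNet. A condensed sketch of the reworked inputs; the list comprehension compresses the diff's two literal entries, and the import path assumes diffusers re-exports `randn_tensor` from `diffusers.utils`, as the test's usage suggests:

    import torch
    from diffusers.utils import randn_tensor  # assumed re-export

    generator = torch.manual_seed(0)
    controlnet_embedder_scale_factor = 2

    # One random conditioning tensor per ControlNet in the MultiControlNet setup.
    images = [
        randn_tensor(
            (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
            generator=generator,
        )
        for _ in range(2)
    ]
    inputs = {"image": images, "num_inference_steps": 2, "guidance_scale": 6.0, "output_type": "numpy"}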