Skip to content

Commit 7599624

Browse files
patrickvonplatenJimmy
authored and
Jimmy
committed
Fix controlnet guess mode euler (huggingface#3571)
* Fix guess mode controlnet for euler-like schedulers * make style * Co-authored-by: Chanchana Sornsoontorn <[email protected]> * Add co author Co-authored-by: Chanchana Sornsoontorn <[email protected]> * 2nd try Co-authored-by: Chanchana Sornsoontorn <[email protected]>
1 parent 00b3356 commit 7599624

File tree

4 files changed

+47
-9
lines changed

4 files changed

+47
-9
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -956,14 +956,15 @@ def __call__(
956956
# controlnet(s) inference
957957
if guess_mode and do_classifier_free_guidance:
958958
# Infer ControlNet only for the conditional batch.
959-
controlnet_latent_model_input = latents
959+
control_model_input = latents
960+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
960961
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
961962
else:
962-
controlnet_latent_model_input = latent_model_input
963+
control_model_input = latent_model_input
963964
controlnet_prompt_embeds = prompt_embeds
964965

965966
down_block_res_samples, mid_block_res_sample = self.controlnet(
966-
controlnet_latent_model_input,
967+
control_model_input,
967968
t,
968969
encoder_hidden_states=controlnet_prompt_embeds,
969970
controlnet_cond=image,

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1034,14 +1034,15 @@ def __call__(
10341034
# controlnet(s) inference
10351035
if guess_mode and do_classifier_free_guidance:
10361036
# Infer ControlNet only for the conditional batch.
1037-
controlnet_latent_model_input = latents
1037+
control_model_input = latents
1038+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
10381039
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
10391040
else:
1040-
controlnet_latent_model_input = latent_model_input
1041+
control_model_input = latent_model_input
10411042
controlnet_prompt_embeds = prompt_embeds
10421043

10431044
down_block_res_samples, mid_block_res_sample = self.controlnet(
1044-
controlnet_latent_model_input,
1045+
control_model_input,
10451046
t,
10461047
encoder_hidden_states=controlnet_prompt_embeds,
10471048
controlnet_cond=control_image,

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1248,16 +1248,18 @@ def __call__(
12481248
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
12491249
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
12501250

1251+
# controlnet(s) inference
12511252
if guess_mode and do_classifier_free_guidance:
12521253
# Infer ControlNet only for the conditional batch.
1253-
controlnet_latent_model_input = latents
1254+
control_model_input = latents
1255+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
12541256
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
12551257
else:
1256-
controlnet_latent_model_input = latent_model_input
1258+
control_model_input = latent_model_input
12571259
controlnet_prompt_embeds = prompt_embeds
12581260

12591261
down_block_res_samples, mid_block_res_sample = self.controlnet(
1260-
controlnet_latent_model_input,
1262+
control_model_input,
12611263
t,
12621264
encoder_hidden_states=controlnet_prompt_embeds,
12631265
controlnet_cond=control_image,

tests/pipelines/controlnet/test_controlnet.py

+34
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
AutoencoderKL,
2727
ControlNetModel,
2828
DDIMScheduler,
29+
EulerDiscreteScheduler,
2930
StableDiffusionControlNetPipeline,
3031
UNet2DConditionModel,
3132
)
@@ -644,6 +645,39 @@ def test_canny_guess_mode(self):
644645
expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
645646
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
646647

648+
def test_canny_guess_mode_euler(self):
649+
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
650+
651+
pipe = StableDiffusionControlNetPipeline.from_pretrained(
652+
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
653+
)
654+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
655+
pipe.enable_model_cpu_offload()
656+
pipe.set_progress_bar_config(disable=None)
657+
658+
generator = torch.Generator(device="cpu").manual_seed(0)
659+
prompt = ""
660+
image = load_image(
661+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
662+
)
663+
664+
output = pipe(
665+
prompt,
666+
image,
667+
generator=generator,
668+
output_type="np",
669+
num_inference_steps=3,
670+
guidance_scale=3.0,
671+
guess_mode=True,
672+
)
673+
674+
image = output.images[0]
675+
assert image.shape == (768, 512, 3)
676+
677+
image_slice = image[-3:, -3:, -1]
678+
expected_slice = np.array([0.1655, 0.1721, 0.1623, 0.1685, 0.1711, 0.1646, 0.1651, 0.1631, 0.1494])
679+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
680+
647681
@require_torch_2
648682
def test_stable_diffusion_compile(self):
649683
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)

0 commit comments

Comments
 (0)