Skip to content

Commit a7fbbe1

Browse files
Fix controlnet guess mode euler (huggingface#3571)
* Fix guess mode controlnet for euler-like schedulers * make style * Co-authored-by: Chanchana Sornsoontorn <[email protected]> * Add co author Co-authored-by: Chanchana Sornsoontorn <[email protected]> * 2nd try Co-authored-by: Chanchana Sornsoontorn <[email protected]>
1 parent b3abb5e commit a7fbbe1

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

pipelines/controlnet/pipeline_controlnet.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -956,14 +956,15 @@ def __call__(
956956
# controlnet(s) inference
957957
if guess_mode and do_classifier_free_guidance:
958958
# Infer ControlNet only for the conditional batch.
959-
controlnet_latent_model_input = latents
959+
control_model_input = latents
960+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
960961
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
961962
else:
962-
controlnet_latent_model_input = latent_model_input
963+
control_model_input = latent_model_input
963964
controlnet_prompt_embeds = prompt_embeds
964965

965966
down_block_res_samples, mid_block_res_sample = self.controlnet(
966-
controlnet_latent_model_input,
967+
control_model_input,
967968
t,
968969
encoder_hidden_states=controlnet_prompt_embeds,
969970
controlnet_cond=image,

pipelines/controlnet/pipeline_controlnet_img2img.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1034,14 +1034,15 @@ def __call__(
10341034
# controlnet(s) inference
10351035
if guess_mode and do_classifier_free_guidance:
10361036
# Infer ControlNet only for the conditional batch.
1037-
controlnet_latent_model_input = latents
1037+
control_model_input = latents
1038+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
10381039
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
10391040
else:
1040-
controlnet_latent_model_input = latent_model_input
1041+
control_model_input = latent_model_input
10411042
controlnet_prompt_embeds = prompt_embeds
10421043

10431044
down_block_res_samples, mid_block_res_sample = self.controlnet(
1044-
controlnet_latent_model_input,
1045+
control_model_input,
10451046
t,
10461047
encoder_hidden_states=controlnet_prompt_embeds,
10471048
controlnet_cond=control_image,

pipelines/controlnet/pipeline_controlnet_inpaint.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1248,16 +1248,18 @@ def __call__(
12481248
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
12491249
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
12501250

1251+
# controlnet(s) inference
12511252
if guess_mode and do_classifier_free_guidance:
12521253
# Infer ControlNet only for the conditional batch.
1253-
controlnet_latent_model_input = latents
1254+
control_model_input = latents
1255+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
12541256
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
12551257
else:
1256-
controlnet_latent_model_input = latent_model_input
1258+
control_model_input = latent_model_input
12571259
controlnet_prompt_embeds = prompt_embeds
12581260

12591261
down_block_res_samples, mid_block_res_sample = self.controlnet(
1260-
controlnet_latent_model_input,
1262+
control_model_input,
12611263
t,
12621264
encoder_hidden_states=controlnet_prompt_embeds,
12631265
controlnet_cond=control_image,

0 commit comments

Comments
 (0)