Skip to content

Commit c8cc4f0

Browse files
patrickvonplaten and Chillee
authored and committed
[Torch 2.0 compile] Fix more torch compile breaks (huggingface#3313)
* Fix more torch compile breaks
* add tests
* Fix all
* fix controlnet
* fix more
* Add Horace He as co-author.
  Co-authored-by: Horace He <[email protected]>
* Add Horace He as co-author.
  Co-authored-by: Horace He <[email protected]>

---------

Co-authored-by: Horace He <[email protected]>
1 parent 863bb75 commit c8cc4f0

22 files changed

+219
-78
lines changed

Diff for: src/diffusers/models/controlnet.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def forward(
498498
# timesteps does not contain any weights and will always return f32 tensors
499499
# but time_embedding might actually be running in fp16. so we need to cast here.
500500
# there might be better ways to encapsulate this.
501-
t_emb = t_emb.to(dtype=self.dtype)
501+
t_emb = t_emb.to(dtype=sample.dtype)
502502

503503
emb = self.time_embedding(t_emb, timestep_cond)
504504

@@ -517,7 +517,7 @@ def forward(
517517

518518
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
519519

520-
sample += controlnet_cond
520+
sample = sample + controlnet_cond
521521

522522
# 3. down
523523
down_block_res_samples = (sample,)
@@ -551,21 +551,22 @@ def forward(
551551

552552
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
553553
down_block_res_sample = controlnet_block(down_block_res_sample)
554-
controlnet_down_block_res_samples += (down_block_res_sample,)
554+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
555555

556556
down_block_res_samples = controlnet_down_block_res_samples
557557

558558
mid_block_res_sample = self.controlnet_mid_block(sample)
559559

560560
# 6. scaling
561561
if guess_mode:
562-
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0
563-
scales *= conditioning_scale
562+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
563+
564+
scales = scales * conditioning_scale
564565
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
565-
mid_block_res_sample *= scales[-1] # last one
566+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
566567
else:
567568
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
568-
mid_block_res_sample *= conditioning_scale
569+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
569570

570571
if self.config.global_pool_conditions:
571572
down_block_res_samples = [

Diff for: src/diffusers/models/unet_2d_condition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ def forward(
740740
down_block_res_samples, down_block_additional_residuals
741741
):
742742
down_block_res_sample = down_block_res_sample + down_block_additional_residual
743-
new_down_block_res_samples += (down_block_res_sample,)
743+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
744744

745745
down_block_res_samples = new_down_block_res_samples
746746

Diff for: src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def decode_latents(self, latents):
457457
FutureWarning,
458458
)
459459
latents = 1 / self.vae.config.scaling_factor * latents
460-
image = self.vae.decode(latents).sample
460+
image = self.vae.decode(latents, return_dict=False)[0]
461461
image = (image / 2 + 0.5).clamp(0, 1)
462462
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463463
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -728,15 +728,16 @@ def __call__(
728728
t,
729729
encoder_hidden_states=prompt_embeds,
730730
cross_attention_kwargs=cross_attention_kwargs,
731-
).sample
731+
return_dict=False,
732+
)[0]
732733

733734
# perform guidance
734735
if do_classifier_free_guidance:
735736
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
736737
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
737738

738739
# compute the previous noisy sample x_t -> x_t-1
739-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
740+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
740741

741742
# call the callback, if provided
742743
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -745,7 +746,7 @@ def __call__(
745746
callback(i, t, latents)
746747

747748
if not output_type == "latent":
748-
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
749+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
749750
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
750751
else:
751752
image = latents

Diff for: src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,8 @@ def __call__(
918918
t,
919919
encoder_hidden_states=prompt_embeds,
920920
cross_attention_kwargs=cross_attention_kwargs,
921-
).sample
921+
return_dict=False,
922+
)[0]
922923

923924
# perform guidance
924925
if do_classifier_free_guidance:
@@ -930,8 +931,8 @@ def __call__(
930931

931932
# compute the previous noisy sample x_t -> x_t-1
932933
intermediate_images = self.scheduler.step(
933-
noise_pred, t, intermediate_images, **extra_step_kwargs
934-
).prev_sample
934+
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
935+
)[0]
935936

936937
# call the callback, if provided
937938
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

Diff for: src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,8 @@ def __call__(
10361036
encoder_hidden_states=prompt_embeds,
10371037
class_labels=noise_level,
10381038
cross_attention_kwargs=cross_attention_kwargs,
1039-
).sample
1039+
return_dict=False,
1040+
)[0]
10401041

10411042
# perform guidance
10421043
if do_classifier_free_guidance:
@@ -1048,8 +1049,8 @@ def __call__(
10481049

10491050
# compute the previous noisy sample x_t -> x_t-1
10501051
intermediate_images = self.scheduler.step(
1051-
noise_pred, t, intermediate_images, **extra_step_kwargs
1052-
).prev_sample
1052+
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
1053+
)[0]
10531054

10541055
# call the callback, if provided
10551056
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

Diff for: src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,8 @@ def __call__(
10331033
t,
10341034
encoder_hidden_states=prompt_embeds,
10351035
cross_attention_kwargs=cross_attention_kwargs,
1036-
).sample
1036+
return_dict=False,
1037+
)[0]
10371038

10381039
# perform guidance
10391040
if do_classifier_free_guidance:
@@ -1047,8 +1048,8 @@ def __call__(
10471048
prev_intermediate_images = intermediate_images
10481049

10491050
intermediate_images = self.scheduler.step(
1050-
noise_pred, t, intermediate_images, **extra_step_kwargs
1051-
).prev_sample
1051+
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
1052+
)[0]
10521053

10531054
intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images
10541055

Diff for: src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,8 @@ def __call__(
11431143
encoder_hidden_states=prompt_embeds,
11441144
class_labels=noise_level,
11451145
cross_attention_kwargs=cross_attention_kwargs,
1146-
).sample
1146+
return_dict=False,
1147+
)[0]
11471148

11481149
# perform guidance
11491150
if do_classifier_free_guidance:
@@ -1157,8 +1158,8 @@ def __call__(
11571158
prev_intermediate_images = intermediate_images
11581159

11591160
intermediate_images = self.scheduler.step(
1160-
noise_pred, t, intermediate_images, **extra_step_kwargs
1161-
).prev_sample
1161+
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
1162+
)[0]
11621163

11631164
intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images
11641165

Diff for: src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -886,7 +886,8 @@ def __call__(
886886
encoder_hidden_states=prompt_embeds,
887887
class_labels=noise_level,
888888
cross_attention_kwargs=cross_attention_kwargs,
889-
).sample
889+
return_dict=False,
890+
)[0]
890891

891892
# perform guidance
892893
if do_classifier_free_guidance:
@@ -898,8 +899,8 @@ def __call__(
898899

899900
# compute the previous noisy sample x_t -> x_t-1
900901
intermediate_images = self.scheduler.step(
901-
noise_pred, t, intermediate_images, **extra_step_kwargs
902-
).prev_sample
902+
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
903+
)[0]
903904

904905
# call the callback, if provided
905906
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

Diff for: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

+40-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121
import PIL.Image
2222
import torch
23+
import torch.nn.functional as F
2324
from torch import nn
2425
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
2526

@@ -579,9 +580,20 @@ def check_inputs(
579580
)
580581

581582
# Check `image`
582-
if isinstance(self.controlnet, ControlNetModel):
583+
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
584+
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
585+
)
586+
if (
587+
isinstance(self.controlnet, ControlNetModel)
588+
or is_compiled
589+
and isinstance(self.controlnet._orig_mod, ControlNetModel)
590+
):
583591
self.check_image(image, prompt, prompt_embeds)
584-
elif isinstance(self.controlnet, MultiControlNetModel):
592+
elif (
593+
isinstance(self.controlnet, MultiControlNetModel)
594+
or is_compiled
595+
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
596+
):
585597
if not isinstance(image, list):
586598
raise TypeError("For multiple controlnets: `image` must be type `list`")
587599

@@ -600,10 +612,18 @@ def check_inputs(
600612
assert False
601613

602614
# Check `controlnet_conditioning_scale`
603-
if isinstance(self.controlnet, ControlNetModel):
615+
if (
616+
isinstance(self.controlnet, ControlNetModel)
617+
or is_compiled
618+
and isinstance(self.controlnet._orig_mod, ControlNetModel)
619+
):
604620
if not isinstance(controlnet_conditioning_scale, float):
605621
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
606-
elif isinstance(self.controlnet, MultiControlNetModel):
622+
elif (
623+
isinstance(self.controlnet, MultiControlNetModel)
624+
or is_compiled
625+
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
626+
):
607627
if isinstance(controlnet_conditioning_scale, list):
608628
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
609629
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
@@ -910,7 +930,14 @@ def __call__(
910930
)
911931

912932
# 4. Prepare image
913-
if isinstance(self.controlnet, ControlNetModel):
933+
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
934+
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
935+
)
936+
if (
937+
isinstance(self.controlnet, ControlNetModel)
938+
or is_compiled
939+
and isinstance(self.controlnet._orig_mod, ControlNetModel)
940+
):
914941
image = self.prepare_image(
915942
image=image,
916943
width=width,
@@ -922,7 +949,11 @@ def __call__(
922949
do_classifier_free_guidance=do_classifier_free_guidance,
923950
guess_mode=guess_mode,
924951
)
925-
elif isinstance(self.controlnet, MultiControlNetModel):
952+
elif (
953+
isinstance(self.controlnet, MultiControlNetModel)
954+
or is_compiled
955+
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
956+
):
926957
images = []
927958

928959
for image_ in image:
@@ -1006,15 +1037,16 @@ def __call__(
10061037
cross_attention_kwargs=cross_attention_kwargs,
10071038
down_block_additional_residuals=down_block_res_samples,
10081039
mid_block_additional_residual=mid_block_res_sample,
1009-
).sample
1040+
return_dict=False,
1041+
)[0]
10101042

10111043
# perform guidance
10121044
if do_classifier_free_guidance:
10131045
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
10141046
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
10151047

10161048
# compute the previous noisy sample x_t -> x_t-1
1017-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1049+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
10181050

10191051
# call the callback, if provided
10201052
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

Diff for: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -677,15 +677,17 @@ def __call__(
677677
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
678678

679679
# predict the noise residual
680-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
680+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
681+
0
682+
]
681683

682684
# perform guidance
683685
if do_classifier_free_guidance:
684686
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
685687
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
686688

687689
# compute the previous noisy sample x_t -> x_t-1
688-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
690+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
689691

690692
# call the callback, if provided
691693
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

Diff for: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def decode_latents(self, latents):
462462
FutureWarning,
463463
)
464464
latents = 1 / self.vae.config.scaling_factor * latents
465-
image = self.vae.decode(latents).sample
465+
image = self.vae.decode(latents, return_dict=False)[0]
466466
image = (image / 2 + 0.5).clamp(0, 1)
467467
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
468468
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -734,15 +734,16 @@ def __call__(
734734
t,
735735
encoder_hidden_states=prompt_embeds,
736736
cross_attention_kwargs=cross_attention_kwargs,
737-
).sample
737+
return_dict=False,
738+
)[0]
738739

739740
# perform guidance
740741
if do_classifier_free_guidance:
741742
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
742743
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
743744

744745
# compute the previous noisy sample x_t -> x_t-1
745-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
746+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
746747

747748
# call the callback, if provided
748749
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -751,7 +752,7 @@ def __call__(
751752
callback(i, t, latents)
752753

753754
if not output_type == "latent":
754-
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
755+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
755756
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
756757
else:
757758
image = latents

Diff for: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -878,15 +878,17 @@ def __call__(
878878
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
879879

880880
# predict the noise residual
881-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
881+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
882+
0
883+
]
882884

883885
# perform guidance
884886
if do_classifier_free_guidance:
885887
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
886888
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
887889

888890
# compute the previous noisy sample x_t -> x_t-1
889-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
891+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
890892

891893
# call the callback, if provided
892894
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

Diff for: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -690,15 +690,17 @@ def __call__(
690690
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
691691

692692
# predict the noise residual
693-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
693+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
694+
0
695+
]
694696

695697
# perform guidance
696698
if do_classifier_free_guidance:
697699
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
698700
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
699701

700702
# compute the previous noisy sample x_t -> x_t-1
701-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
703+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
702704
# masking
703705
if add_predicted_noise:
704706
init_latents_proper = self.scheduler.add_noise(

0 commit comments

Comments
 (0)