[Torch 2.0 compile] Fix more torch compile breaks #3313


Merged · 7 commits · May 2, 2023
15 changes: 8 additions & 7 deletions src/diffusers/models/controlnet.py
@@ -498,7 +498,7 @@ def forward(
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=self.dtype)
+ t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)

@@ -517,7 +517,7 @@ def forward(

controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

- sample += controlnet_cond
+ sample = sample + controlnet_cond
Member: How does this break compilation? I don't get it 🤯

@Chillee (Contributor, May 3, 2023): I don't think this is necessary for torch.compile on nightly. I'm guessing on an older release it was just breaking for no good reason :) I guess perhaps in the past we didn't support in-place ops properly yet?

Member: Thanks for the clarification!

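For readers hitting this on the 2.0 release line, here is a minimal, hypothetical sketch of the two forms (not code from this PR). Note that for tuples, as in the `controlnet_down_block_res_samples` change below, `+=` already rebinds rather than mutates, so that rewrite just makes the rebinding explicit; for tensors, `+=` is a true in-place op that early Dynamo releases handled less reliably than the out-of-place form:

```python
import torch

def add_inplace(sample: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    sample += cond  # in-place: mutates the input tensor
    return sample

def add_out_of_place(sample: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    return sample + cond  # out-of-place: allocates a fresh tensor

# Both compile cleanly on recent nightlies; the out-of-place form is the
# conservative choice for torch 2.0.x.
fn = torch.compile(add_out_of_place)
out = fn(torch.randn(2, 4), torch.randn(2, 4))
```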

# 3. down
down_block_res_samples = (sample,)
@@ -551,21 +551,22 @@ def forward(

for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
- controlnet_down_block_res_samples += (down_block_res_sample,)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

down_block_res_samples = controlnet_down_block_res_samples

mid_block_res_sample = self.controlnet_mid_block(sample)

# 6. scaling
if guess_mode:
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
- scales *= conditioning_scale
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
+
+ scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
- mid_block_res_sample *= scales[-1]  # last one
+ mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
- mid_block_res_sample *= conditioning_scale
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale

if self.config.global_pool_conditions:
down_block_res_samples = [
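One non-mechanical change in this file: `torch.logspace` previously allocated `scales` on the CPU (the factory-function default), and mixing a CPU tensor into GPU math inside a compiled region can force a sync or a graph break. Passing `device=sample.device` keeps the scaling computation device-local. A toy sketch of the fixed pattern (shapes and names illustrative, not from the PR):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
down_block_res_samples = [torch.randn(2, 4, device=device) for _ in range(3)]

# Allocate the scales directly on the target device so the multiplies
# below never touch a CPU tensor.
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=device)  # 0.1 to 1.0
scaled = [s * scale for s, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = down_block_res_samples[-1] * scales[-1]  # last one
```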
2 changes: 1 addition & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -740,7 +740,7 @@ def forward(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
- new_down_block_res_samples += (down_block_res_sample,)
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

down_block_res_samples = new_down_block_res_samples

@@ -457,7 +457,7 @@ def decode_latents(self, latents):
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
- image = self.vae.decode(latents).sample
+ image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -728,15 +728,16 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -745,7 +746,7 @@ def __call__(
callback(i, t, latents)

if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
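The `.sample` → `return_dict=False)[0]` rewrite that recurs in every pipeline below follows the same logic: model and scheduler calls that package results in an output dataclass and are then read by attribute could graph-break under early torch.compile, while `return_dict=False` returns a plain tuple that Dynamo traces without trouble. A self-contained toy sketch of the two call styles (classes are hypothetical stand-ins, not diffusers code):

```python
import torch
from dataclasses import dataclass

@dataclass
class ToyOutput:  # stand-in for a diffusers ModelOutput
    sample: torch.Tensor

class ToyUNet(torch.nn.Module):
    def forward(self, x, return_dict=True):
        out = x * 2.0
        # A plain tuple keeps custom-class construction out of the traced
        # region, which early Dynamo releases handled poorly.
        return ToyOutput(sample=out) if return_dict else (out,)

model = torch.compile(ToyUNet())
noise_pred = model(torch.randn(2, 4), return_dict=False)[0]  # tuple indexing
```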
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -918,7 +918,8 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
@@ -930,8 +931,8 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1036,7 +1036,8 @@ def __call__(
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
@@ -1048,8 +1049,8 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1033,7 +1033,8 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
@@ -1047,8 +1048,8 @@ def __call__(
prev_intermediate_images = intermediate_images

intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]

intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

@@ -1143,7 +1143,8 @@ def __call__(
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
@@ -1157,8 +1158,8 @@ def __call__(
prev_intermediate_images = intermediate_images

intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]

intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

@@ -886,7 +886,8 @@ def __call__(
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
@@ -898,8 +899,8 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -20,6 +20,7 @@
import numpy as np
import PIL.Image
import torch
+ import torch.nn.functional as F
from torch import nn
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

@@ -579,9 +580,20 @@ def check_inputs(
)

# Check `image`
- if isinstance(self.controlnet, ControlNetModel):
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+     self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+     isinstance(self.controlnet, ControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
self.check_image(image, prompt, prompt_embeds)
- elif isinstance(self.controlnet, MultiControlNetModel):
+ elif (
+     isinstance(self.controlnet, MultiControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")

@@ -600,10 +612,18 @@ def check_inputs(
assert False

# Check `controlnet_conditioning_scale`
if isinstance(self.controlnet, ControlNetModel):
if (
isinstance(self.controlnet, ControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
elif (
isinstance(self.controlnet, MultiControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
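A note on the `is_compiled` guard introduced above: `torch.compile` wraps a module in `torch._dynamo.eval_frame.OptimizedModule`, so a plain `isinstance` check against the original class fails, and the wrapped module has to be reached through `._orig_mod`. The `hasattr(F, "scaled_dot_product_attention")` test is a cheap torch>=2.0 probe that guards the `torch._dynamo` access on older installs. A minimal sketch with a toy class (illustrative only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ControlNetLike(nn.Module):  # toy stand-in for ControlNetModel
    def forward(self, x):
        return x

net = torch.compile(ControlNetLike())

if hasattr(F, "scaled_dot_product_attention"):  # torch >= 2.0
    print(isinstance(net, ControlNetLike))                            # False: wrapped
    print(isinstance(net, torch._dynamo.eval_frame.OptimizedModule))  # True
    print(isinstance(net._orig_mod, ControlNetLike))                  # True
```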
@@ -910,7 +930,14 @@ def __call__(
)

# 4. Prepare image
- if isinstance(self.controlnet, ControlNetModel):
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+     self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+     isinstance(self.controlnet, ControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
image = self.prepare_image(
image=image,
width=width,
@@ -922,7 +949,11 @@ def __call__(
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(self.controlnet, MultiControlNetModel):
elif (
isinstance(self.controlnet, MultiControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
images = []

for image_ in image:
@@ -1006,15 +1037,16 @@ def __call__(
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -677,15 +677,17 @@ def __call__(
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)

# predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+     0
+ ]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -462,7 +462,7 @@ def decode_latents(self, latents):
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
- image = self.vae.decode(latents).sample
+ image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -734,15 +734,16 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -751,7 +752,7 @@ def __call__(
callback(i, t, latents)

if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
@@ -878,15 +878,17 @@ def __call__(
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

# predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+     0
+ ]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -690,15 +690,17 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+     0
+ ]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# masking
if add_predicted_noise:
init_latents_proper = self.scheduler.add_noise(