Commit 6766a81

Support non square image generation for StableDiffusionSAGPipeline (#2629)
* Support non square image generation for StableDiffusionSAGPipeline
* Fix style
1 parent bbab855 · commit 6766a81
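Summary: previously the SAG pipeline recovered the self-attention map's spatial size inside sag_masking as math.isqrt(hw1), which is only valid when the attention map, and therefore the requested image, is square. This commit instead captures the mid-block attention map's actual (height, width) with a forward hook during denoising and threads it through to sag_masking, so any aspect ratio works. A minimal usage sketch of what the change enables (the model ID and the 768×512 size mirror the new test below; the prompt and device are illustrative):

import torch
from diffusers import StableDiffusionSAGPipeline

pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cuda")

# Non-square output: before this commit, width != height made sag_masking's
# square reshape fail; with the hooked map size it succeeds.
image = pipe(
    "a photograph of a lighthouse at dawn",
    width=768,
    height=512,
    guidance_scale=7.5,
    sag_scale=1.0,
    num_inference_steps=20,
).images[0]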

File tree

2 files changed (+93 −61 lines):
  src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
  tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

+71 −61 lines changed
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
@@ -606,64 +605,73 @@ def __call__(
         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # perform self-attention guidance with the stored self-attentnion map
-                if do_self_attention_guidance:
-                    # classifier-free guidance produces two chunks of attention map
-                    # and we only use unconditional one according to equation (24)
-                    # in https://arxiv.org/pdf/2210.00939.pdf
+
+        map_size = None
+
+        def get_map_size(module, input, output):
+            nonlocal map_size
+            map_size = output.sample.shape[-2:]
+
+        with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                    ).sample
+
+                    # perform guidance
                     if do_classifier_free_guidance:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
-                        # get the stored attention maps
-                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
-                        )
-                        uncond_emb, _ = prompt_embeds.chunk(2)
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
-                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
-                    else:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred, t)
-                        # get the stored attention maps
-                        cond_attn = store_processor.attention_probs
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
-                        )
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
-                        noise_pred += sag_scale * (noise_pred - degraded_pred)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # perform self-attention guidance with the stored self-attentnion map
+                    if do_self_attention_guidance:
+                        # classifier-free guidance produces two chunks of attention map
+                        # and we only use unconditional one according to equation (24)
+                        # in https://arxiv.org/pdf/2210.00939.pdf
+                        if do_classifier_free_guidance:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
+                            # get the stored attention maps
+                            uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                            )
+                            uncond_emb, _ = prompt_embeds.chunk(2)
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                            noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                        else:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred, t)
+                            # get the stored attention maps
+                            cond_attn = store_processor.attention_probs
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
+                            )
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                            noise_pred += sag_scale * (noise_pred - degraded_pred)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if callback is not None and i % callback_steps == 0:
+                            callback(i, t, latents)
 
         # 8. Post-processing
         image = self.decode_latents(latents)
@@ -680,20 +688,22 @@ def __call__(
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
-    def sag_masking(self, original_latents, attn_map, t, eps):
+    def sag_masking(self, original_latents, attn_map, map_size, t, eps):
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
         h = self.unet.attention_head_dim
         if isinstance(h, list):
             h = h[-1]
-        map_size = math.isqrt(hw1)
 
         # Produce attention mask
         attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
         attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+            attn_mask.reshape(b, map_size[0], map_size[1])
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
         )
         attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
 
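The mechanism in the diff above is a standard torch.nn.Module forward hook: register_forward_hook returns a RemovableHandle, which also works as a context manager that detaches the hook on exit, so map_size is recorded on every UNet forward pass inside the loop and cleaned up afterwards. A standalone sketch of the pattern (using a plain Conv2d whose output is a bare tensor; the hooked attention block in the pipeline returns an output object, hence output.sample.shape[-2:] above):

import torch
import torch.nn as nn

map_size = None

def get_map_size(module, input, output):
    # record the spatial (height, width) of the module's output
    global map_size
    map_size = output.shape[-2:]

conv = nn.Conv2d(4, 4, kernel_size=3, padding=1)
with conv.register_forward_hook(get_map_size):  # hook is removed on exit
    conv(torch.randn(1, 4, 64, 96))

print(map_size)  # torch.Size([64, 96])

With the size captured this way, sag_masking can reshape the flattened mask to (b, map_size[0], map_size[1]) and interpolate it to the latent resolution, instead of assuming a square math.isqrt(hw1) grid.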

tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py

+22 −0 lines changed
@@ -160,3 +160,25 @@ def test_stable_diffusion_2(self):
         expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+
+    def test_stable_diffusion_2_non_square(self):
+        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
+        sag_pipe = sag_pipe.to(torch_device)
+        sag_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "."
+        generator = torch.manual_seed(0)
+        output = sag_pipe(
+            [prompt],
+            width=768,
+            height=512,
+            generator=generator,
+            guidance_scale=7.5,
+            sag_scale=1.0,
+            num_inference_steps=20,
+            output_type="np",
+        )
+
+        image = output.images
+
+        assert image.shape == (1, 512, 768, 3)
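The 768×512 test case is exactly the shape the old code could not handle. A quick check of the size arithmetic, assuming the usual Stable Diffusion factors (the VAE maps pixels to latents at 1/8 scale, and the UNet mid-block runs at a further 1/8 of the latent resolution):

import math

height, width = 512, 768
latent_h, latent_w = height // 8, width // 8  # 64, 96 latent pixels
map_h, map_w = latent_h // 8, latent_w // 8   # 8, 12 mid-block positions
hw1 = map_h * map_w                           # 96 flattened attention positions

# The removed code assumed a square map: isqrt(96) = 9, but 9 * 9 = 81 != 96,
# so reshaping 96 elements to (9, 9) raised an error. The hooked (8, 12)
# shape reshapes cleanly for any aspect ratio.
assert math.isqrt(hw1) ** 2 != hw1

The expected shape (1, 512, 768, 3) also confirms the NumPy output is NHWC: height 512, width 768.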
