13
13
# limitations under the License.
14
14
15
15
import inspect
16
- import math
17
16
from typing import Any , Callable , Dict , List , Optional , Union
18
17
19
18
import torch
@@ -606,64 +605,73 @@ def __call__(
606
605
store_processor = CrossAttnStoreProcessor ()
607
606
self .unet .mid_block .attentions [0 ].transformer_blocks [0 ].attn1 .processor = store_processor
608
607
num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
609
- with self . progress_bar ( total = num_inference_steps ) as progress_bar :
610
- for i , t in enumerate ( timesteps ):
611
- # expand the latents if we are doing classifier free guidance
612
- latent_model_input = torch . cat ([ latents ] * 2 ) if do_classifier_free_guidance else latents
613
- latent_model_input = self . scheduler . scale_model_input ( latent_model_input , t )
614
-
615
- # predict the noise residual
616
- noise_pred = self .unet (
617
- latent_model_input ,
618
- t ,
619
- encoder_hidden_states = prompt_embeds ,
620
- cross_attention_kwargs = cross_attention_kwargs ,
621
- ). sample
622
-
623
- # perform guidance
624
- if do_classifier_free_guidance :
625
- noise_pred_uncond , noise_pred_text = noise_pred . chunk ( 2 )
626
- noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond )
627
-
628
- # perform self-attention guidance with the stored self-attention map
629
- if do_self_attention_guidance :
630
- # classifier-free guidance produces two chunks of attention map
631
- # and we only use unconditional one according to equation (24)
632
- # in https://arxiv.org/pdf/2210.00939.pdf
608
+
609
+ map_size = None
610
+
611
+ def get_map_size ( module , input , output ):
612
+ nonlocal map_size
613
+ map_size = output . sample . shape [ - 2 :]
614
+
615
+ with self .unet . mid_block . attentions [ 0 ]. register_forward_hook ( get_map_size ):
616
+ with self . progress_bar ( total = num_inference_steps ) as progress_bar :
617
+ for i , t in enumerate ( timesteps ):
618
+ # expand the latents if we are doing classifier free guidance
619
+ latent_model_input = torch . cat ([ latents ] * 2 ) if do_classifier_free_guidance else latents
620
+ latent_model_input = self . scheduler . scale_model_input ( latent_model_input , t )
621
+
622
+ # predict the noise residual
623
+
624
+ noise_pred = self . unet (
625
+ latent_model_input ,
626
+ t ,
627
+ encoder_hidden_states = prompt_embeds ,
628
+ cross_attention_kwargs = cross_attention_kwargs ,
629
+ ). sample
630
+
631
+ # perform guidance
633
632
if do_classifier_free_guidance :
634
- # DDIM-like prediction of x0
635
- pred_x0 = self .pred_x0 (latents , noise_pred_uncond , t )
636
- # get the stored attention maps
637
- uncond_attn , cond_attn = store_processor .attention_probs .chunk (2 )
638
- # self-attention-based degrading of latents
639
- degraded_latents = self .sag_masking (
640
- pred_x0 , uncond_attn , t , self .pred_epsilon (latents , noise_pred_uncond , t )
641
- )
642
- uncond_emb , _ = prompt_embeds .chunk (2 )
643
- # forward and give guidance
644
- degraded_pred = self .unet (degraded_latents , t , encoder_hidden_states = uncond_emb ).sample
645
- noise_pred += sag_scale * (noise_pred_uncond - degraded_pred )
646
- else :
647
- # DDIM-like prediction of x0
648
- pred_x0 = self .pred_x0 (latents , noise_pred , t )
649
- # get the stored attention maps
650
- cond_attn = store_processor .attention_probs
651
- # self-attention-based degrading of latents
652
- degraded_latents = self .sag_masking (
653
- pred_x0 , cond_attn , t , self .pred_epsilon (latents , noise_pred , t )
654
- )
655
- # forward and give guidance
656
- degraded_pred = self .unet (degraded_latents , t , encoder_hidden_states = prompt_embeds ).sample
657
- noise_pred += sag_scale * (noise_pred - degraded_pred )
658
-
659
- # compute the previous noisy sample x_t -> x_t-1
660
- latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs ).prev_sample
661
-
662
- # call the callback, if provided
663
- if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
664
- progress_bar .update ()
665
- if callback is not None and i % callback_steps == 0 :
666
- callback (i , t , latents )
633
+ noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
634
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
635
+
636
+ # perform self-attention guidance with the stored self-attention map
637
+ if do_self_attention_guidance :
638
+ # classifier-free guidance produces two chunks of attention map
639
+ # and we only use unconditional one according to equation (24)
640
+ # in https://arxiv.org/pdf/2210.00939.pdf
641
+ if do_classifier_free_guidance :
642
+ # DDIM-like prediction of x0
643
+ pred_x0 = self .pred_x0 (latents , noise_pred_uncond , t )
644
+ # get the stored attention maps
645
+ uncond_attn , cond_attn = store_processor .attention_probs .chunk (2 )
646
+ # self-attention-based degrading of latents
647
+ degraded_latents = self .sag_masking (
648
+ pred_x0 , uncond_attn , map_size , t , self .pred_epsilon (latents , noise_pred_uncond , t )
649
+ )
650
+ uncond_emb , _ = prompt_embeds .chunk (2 )
651
+ # forward and give guidance
652
+ degraded_pred = self .unet (degraded_latents , t , encoder_hidden_states = uncond_emb ).sample
653
+ noise_pred += sag_scale * (noise_pred_uncond - degraded_pred )
654
+ else :
655
+ # DDIM-like prediction of x0
656
+ pred_x0 = self .pred_x0 (latents , noise_pred , t )
657
+ # get the stored attention maps
658
+ cond_attn = store_processor .attention_probs
659
+ # self-attention-based degrading of latents
660
+ degraded_latents = self .sag_masking (
661
+ pred_x0 , cond_attn , map_size , t , self .pred_epsilon (latents , noise_pred , t )
662
+ )
663
+ # forward and give guidance
664
+ degraded_pred = self .unet (degraded_latents , t , encoder_hidden_states = prompt_embeds ).sample
665
+ noise_pred += sag_scale * (noise_pred - degraded_pred )
666
+
667
+ # compute the previous noisy sample x_t -> x_t-1
668
+ latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs ).prev_sample
669
+
670
+ # call the callback, if provided
671
+ if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
672
+ progress_bar .update ()
673
+ if callback is not None and i % callback_steps == 0 :
674
+ callback (i , t , latents )
667
675
668
676
# 8. Post-processing
669
677
image = self .decode_latents (latents )
@@ -680,20 +688,22 @@ def __call__(
680
688
681
689
return StableDiffusionPipelineOutput (images = image , nsfw_content_detected = has_nsfw_concept )
682
690
683
- def sag_masking (self , original_latents , attn_map , t , eps ):
691
+ def sag_masking (self , original_latents , attn_map , map_size , t , eps ):
684
692
# Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
685
693
bh , hw1 , hw2 = attn_map .shape
686
694
b , latent_channel , latent_h , latent_w = original_latents .shape
687
695
h = self .unet .attention_head_dim
688
696
if isinstance (h , list ):
689
697
h = h [- 1 ]
690
- map_size = math .isqrt (hw1 )
691
698
692
699
# Produce attention mask
693
700
attn_map = attn_map .reshape (b , h , hw1 , hw2 )
694
701
attn_mask = attn_map .mean (1 , keepdim = False ).sum (1 , keepdim = False ) > 1.0
695
702
attn_mask = (
696
- attn_mask .reshape (b , map_size , map_size ).unsqueeze (1 ).repeat (1 , latent_channel , 1 , 1 ).type (attn_map .dtype )
703
+ attn_mask .reshape (b , map_size [0 ], map_size [1 ])
704
+ .unsqueeze (1 )
705
+ .repeat (1 , latent_channel , 1 , 1 )
706
+ .type (attn_map .dtype )
697
707
)
698
708
attn_mask = F .interpolate (attn_mask , (latent_h , latent_w ))
699
709
0 commit comments