@@ -431,6 +431,10 @@ def norm_encoder_hidden_states(self, encoder_hidden_states):
 
 
 class AttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
     def __call__(
         self,
         attn: Attention,
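
As a usage sketch only (not part of this diff): a processor like this is assigned to a model's attention layers through `set_attn_processor`, and the `attn_processors` property reports what is installed. The tiny default-config UNet below is just a stand-in for a real checkpoint.

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel()  # stand-in model; normally loaded from a checkpoint

# Install the plain attention processor on every attention layer.
unet.set_attn_processor(AttnProcessor())

# Inspect which processor each attention layer ended up with.
print({name: type(proc).__name__ for name, proc in unet.attn_processors.items()})
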
@@ -516,6 +520,18 @@ def forward(self, hidden_states):
 
 
 class LoRAAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
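
A hedged sketch of how these constructor arguments are typically filled in: the per-layer loop below follows the pattern used in the diffusers LoRA training examples of this era, and the block-name parsing assumes a Stable-Diffusion-style UNet layout rather than anything defined in this file.

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel()  # stand-in model; normally loaded from a checkpoint

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers (attn1) take no cross-attention conditioning.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)
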
@@ -580,6 +596,24 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class CustomDiffusionAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
     def __init__(
         self,
         train_kv=True,
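
An illustrative instantiation of the arguments documented above. The sizes are placeholders chosen for the example, not values taken from the diff; in practice they must match the attention layer the processor is attached to. Because the processor is an `nn.Module`, the newly trained projections appear as its parameters.

from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

proc = CustomDiffusionAttnProcessor(
    train_kv=True,       # learn new key/value projections for the text features
    train_q_out=False,   # keep the original query/output projections frozen
    hidden_size=320,     # placeholder hidden size
    cross_attention_dim=768,  # placeholder text-encoder width
    out_bias=True,
    dropout=0.0,
)
print(sum(p.numel() for p in proc.parameters()))  # only the newly trained projections are parameters
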
@@ -658,6 +692,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class AttnAddedKVProcessor:
+    r"""
+    Processor for performing attention-related computations with extra learnable key and value matrices for the text
+    encoder.
+    """
+
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
@@ -707,6 +746,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class AttnAddedKVProcessor2_0:
+    r"""
+    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+    learnable key and value matrices for the text encoder.
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
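
A small sketch of how a caller might pick between the two added-KV processors, mirroring the `hasattr(F, "scaled_dot_product_attention")` check that `AttnAddedKVProcessor2_0` performs in `__init__`. The `unet` in the trailing comment is an assumption: a model whose attention blocks were built with the extra `add_k_proj`/`add_v_proj` projections.

import torch.nn.functional as F

from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0

# Prefer the fused scaled-dot-product variant when PyTorch 2.0 is available.
if hasattr(F, "scaled_dot_product_attention"):
    proc = AttnAddedKVProcessor2_0()
else:
    proc = AttnAddedKVProcessor()

# The processor would then be assigned to such a model, e.g. unet.set_attn_processor(proc).
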
@@ -765,6 +809,19 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class LoRAAttnAddedKVProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
+    encoder.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
@@ -832,6 +889,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class XFormersAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
 
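
A usage sketch, under the assumption that the xformers package is installed for the actual forward pass: assigning the processor directly is roughly what the existing `enable_xformers_memory_efficient_attention` convenience method does, and `attention_op=None` follows the docstring's recommendation to let xFormers choose the operator.

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import XFormersAttnProcessor

unet = UNet2DConditionModel()  # stand-in model; normally loaded from a checkpoint

# Route every attention layer through xFormers' memory-efficient kernels.
unet.set_attn_processor(XFormersAttnProcessor(attention_op=None))

# Roughly equivalent convenience call on the model (or on a whole pipeline):
# unet.enable_xformers_memory_efficient_attention()
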
@@ -905,6 +973,10 @@ def __call__(
 
 
 class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
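
To make the PyTorch 2.0 dependency concrete, here is a standalone sketch of the primitive this processor wraps. The tensor shapes are arbitrary illustrations of (batch, heads, sequence length, head dimension), not values used by the processor itself.

import torch
import torch.nn.functional as F

# Same capability check AttnProcessor2_0 performs in __init__.
assert hasattr(F, "scaled_dot_product_attention"), "PyTorch 2.0+ is required"

query = torch.randn(2, 8, 64, 40)   # (batch, heads, query tokens, head dim)
key = torch.randn(2, 8, 77, 40)     # (batch, heads, key/value tokens, head dim)
value = torch.randn(2, 8, 77, 40)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 64, 40])
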
@@ -983,6 +1055,23 @@ def __call__(
 
 
 class LoRAXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()
 
@@ -1049,6 +1138,28 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+            as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
+    """
+
     def __init__(
         self,
         train_kv=True,
@@ -1134,6 +1245,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class SlicedAttnProcessor:
+    r"""
+    Processor for implementing sliced attention.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size
 
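
A usage sketch, assuming a Stable-Diffusion-style UNet whose `attention_head_dim` is divisible by the chosen slice size: attention is computed in chunks to lower peak memory at some speed cost.

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import SlicedAttnProcessor

unet = UNet2DConditionModel()  # stand-in model; normally loaded from a checkpoint

# Compute attention in slices; per the docstring, slice_size must divide attention_head_dim.
unet.set_attn_processor(SlicedAttnProcessor(slice_size=2))

# Higher-level entry points that wire up the same kind of processor:
# unet.set_attention_slice(2)           on the model
# pipeline.enable_attention_slicing(2)  on a pipeline
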
@@ -1206,6 +1326,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class SlicedAttnAddedKVProcessor:
+    r"""
+    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size
 