
Commit 2101ea8

[docs] Add AttnProcessor to docs (huggingface#3474)
* add attnprocessor to docs
* fix path to class
* create separate page for attnprocessors
* fix path
* fix path for real
* fill in docstrings
* apply feedback
* apply feedback
1 parent a7fbbe1 commit 2101ea8

1 file changed, +129 -0 lines changed

models/attention_processor.py (+129)
@@ -431,6 +431,10 @@ def norm_encoder_hidden_states(self, encoder_hidden_states):
 
 
 class AttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
     def __call__(
         self,
         attn: Attention,
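
As a usage sketch (not part of this commit): assigning the default processor to every attention layer of a UNet. The checkpoint id is only a placeholder.

# Minimal sketch, assuming the processor API shown above and a placeholder checkpoint id.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
# set_attn_processor accepts a single processor and applies it to every attention layer.
unet.set_attn_processor(AttnProcessor())
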
@@ -516,6 +520,18 @@ def forward(self, hidden_states):
 
 
 class LoRAAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
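
A hedged sketch of how these constructor arguments are typically filled in per layer, loosely following the LoRA training examples elsewhere in the repository; the layer-name parsing below is an assumption about the UNet's processor naming, not something this commit adds.

from diffusers.models.attention_processor import LoRAAttnProcessor

# `unet` is a UNet2DConditionModel, e.g. loaded as in the previous sketch.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # self-attention ("attn1") layers have no cross-attention dimension
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

# set_attn_processor also accepts a dict keyed by processor name.
unet.set_attn_processor(lora_attn_procs)
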

@@ -580,6 +596,24 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class CustomDiffusionAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
     def __init__(
         self,
         train_kv=True,
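
A small construction sketch for the arguments documented above; the concrete sizes are illustrative values for an SD 1.x-style cross-attention layer, not taken from this commit.

from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

proc = CustomDiffusionAttnProcessor(
    train_kv=True,            # learn new key/value projections for the text features
    train_q_out=False,        # keep the original query/output projections frozen
    hidden_size=320,          # width of the attention layer this processor attaches to (illustrative)
    cross_attention_dim=768,  # text-encoder hidden size (illustrative, SD 1.x)
    out_bias=True,
    dropout=0.0,
)
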
@@ -658,6 +692,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class AttnAddedKVProcessor:
+    r"""
+    Processor for performing attention-related computations with extra learnable key and value matrices for the text
+    encoder.
+    """
+
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
@@ -707,6 +746,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class AttnAddedKVProcessor2_0:
+    r"""
+    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+    learnable key and value matrices for the text encoder.
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
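
The two added-KV processors above differ only in whether they rely on PyTorch 2.0's scaled_dot_product_attention; a hedged sketch of choosing between them at runtime:

import torch.nn.functional as F

from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
)

# Mirror the hasattr check used in AttnAddedKVProcessor2_0.__init__ above.
processor_cls = (
    AttnAddedKVProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")
    else AttnAddedKVProcessor
)
processor = processor_cls()  # then e.g. model.set_attn_processor(processor) on an added-KV UNet
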
@@ -765,6 +809,19 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class LoRAAttnAddedKVProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
+    encoder.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
@@ -832,6 +889,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class XFormersAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
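
A usage sketch (not part of this commit): the pipeline helper enable_xformers_memory_efficient_attention is the usual way this processor gets installed; setting it on the UNet directly is shown for comparison, and the checkpoint id is a placeholder.

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import XFormersAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Recommended path: let the pipeline install the xFormers processors.
pipe.enable_xformers_memory_efficient_attention(attention_op=None)

# Roughly equivalent, directly on the UNet; attention_op=None lets xFormers pick the operator.
pipe.unet.set_attn_processor(XFormersAttnProcessor(attention_op=None))
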

@@ -905,6 +973,10 @@ def __call__(
 
 
 class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
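
A short sketch of opting into this processor explicitly; on PyTorch 2.0+ it is already the default, so this mainly matters after another processor has been set. `unet` here is assumed to be any UNet2DConditionModel, e.g. loaded as in the first sketch.

import torch.nn.functional as F

from diffusers.models.attention_processor import AttnProcessor2_0

# Only switch if PyTorch 2.0's scaled_dot_product_attention is available.
if hasattr(F, "scaled_dot_product_attention"):
    unet.set_attn_processor(AttnProcessor2_0())
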
@@ -983,6 +1055,23 @@ def __call__(
 
 
 class LoRAXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()
 
@@ -1049,6 +1138,28 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+            as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
+    """
+
     def __init__(
         self,
         train_kv=True,
@@ -1134,6 +1245,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class SlicedAttnProcessor:
+    r"""
+    Processor for implementing sliced attention.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size
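
A usage sketch (not part of this commit): on a pipeline, enable_attention_slicing installs the sliced processors via Attention.set_attention_slice; the explicit form is shown for comparison, and slice_size=2 is illustrative.

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import SlicedAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder checkpoint

# Usual entry point: trades some speed for lower peak memory.
pipe.enable_attention_slicing(slice_size=2)

# Roughly equivalent, setting the processor on the UNet directly.
pipe.unet.set_attn_processor(SlicedAttnProcessor(slice_size=2))
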

@@ -1206,6 +1326,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class SlicedAttnAddedKVProcessor:
+    r"""
+    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size
12111340
