# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import math
+from contextlib import ContextDecorator
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import nn

+from ..utils import deprecate
from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
+from .cross_attention import CrossAttention, SpatialAttnProcessor, XFormersSpatialAttnProcessor
from .embeddings import CombinedTimestepLabelEmbeddings


@@ -57,6 +58,20 @@ def __init__(
        eps: float = 1e-5,
    ):
        super().__init__()
+
+        if _assert_no_deprecated_attention_blocks > 0:
+            raise AssertionError(
+                "Deprecated `AttentionBlock` created while the `assert_no_deprecated_attention_blocks` context"
+                " manager is active."
+            )
+
+        deprecation_message = (
+            "AttentionBlock has been deprecated and will be replaced with CrossAttention. TODO add upgrade"
+            " instructions"
+        )
+
+        deprecate("AttentionBlock", "1.0.0", deprecation_message, standard_warn=True)
+
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
@@ -74,20 +89,6 @@ def __init__(
        self._use_memory_efficient_attention_xformers = False
        self._attention_op = None

-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
@@ -119,59 +120,43 @@ def set_use_memory_efficient_attention_xformers(
        self._attention_op = attention_op

    def forward(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        scale = 1 / math.sqrt(self.channels / self.num_heads)
+        attn = self.as_cross_attention()
+        hidden_states = attn(hidden_states)

-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        return hidden_states

-        if self._use_memory_efficient_attention_xformers:
-            # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
-            )
-            hidden_states = hidden_states.to(query_proj.dtype)
+    def as_cross_attention(self):
+        if not self._use_memory_efficient_attention_xformers:
+            processor = SpatialAttnProcessor()
        else:
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_proj.shape[0],
-                    query_proj.shape[1],
-                    key_proj.shape[1],
-                    dtype=query_proj.dtype,
-                    device=query_proj.device,
-                ),
-                query_proj,
-                key_proj.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-            hidden_states = torch.bmm(attention_probs, value_proj)
+            processor = XFormersSpatialAttnProcessor(self._attention_op)

-        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-
-        # compute next hidden_states
-        hidden_states = self.proj_attn(hidden_states)
+        if self.num_head_size is None:
+            # When `self.num_head_size` is None, there is a single attention head
+            # of all the channels
+            dim_head = self.channels
+        else:
+            dim_head = self.num_head_size
+
+        attn = CrossAttention(
+            self.channels,
+            heads=self.num_heads,
+            dim_head=dim_head,
+            bias=True,
+            upcast_softmax=True,
+            norm_num_groups=self.group_norm.num_groups,
+            processor=processor,
+            eps=self.group_norm.eps,
+            rescale_output_factor=self.rescale_output_factor,
+        )

-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        attn.group_norm = self.group_norm
+        attn.to_q = self.query
+        attn.to_k = self.key
+        attn.to_v = self.value
+        attn.to_out[0] = self.proj_attn

-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
+        return attn


class BasicTransformerBlock(nn.Module):
@@ -480,3 +465,18 @@ def forward(self, x, timestep, class_labels, hidden_dtype=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# tracks how many `assert_no_deprecated_attention_blocks` contexts are currently active
+_assert_no_deprecated_attention_blocks = 0
+
+
+class assert_no_deprecated_attention_blocks(ContextDecorator):
+    def __enter__(self):
+        global _assert_no_deprecated_attention_blocks
+        _assert_no_deprecated_attention_blocks += 1
+        return self
+
+    def __exit__(self, *args):
+        global _assert_no_deprecated_attention_blocks
+        _assert_no_deprecated_attention_blocks -= 1
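
For illustration only (not part of this diff): a minimal sketch of how the new guard could be used, either as a context manager or, thanks to ContextDecorator, as a decorator. The import path and `build_unet` are assumptions, not something this change adds.

# Sketch only: assumes the guard is importable from diffusers.models.attention
# and that `build_unet` is a hypothetical model constructor.
from diffusers.models.attention import assert_no_deprecated_attention_blocks


def test_no_deprecated_attention_blocks():
    # Any AttentionBlock.__init__ that runs inside this block raises AssertionError,
    # because __init__ checks the module-level counter bumped by __enter__.
    with assert_no_deprecated_attention_blocks():
        build_unet()


@assert_no_deprecated_attention_blocks()
def load_model():
    # The same guard can wrap a whole function call when used as a decorator.
    return build_unet()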
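
Also for illustration (not in the diff): a sketch of how `as_cross_attention()` could be used to migrate a model in place. Because the conversion assigns the existing `group_norm`, `query`, `key`, `value`, and `proj_attn` modules onto the new `CrossAttention`, the swap preserves the original weights. The recursive helper below is hypothetical; only `AttentionBlock.as_cross_attention()` comes from this change.

# Sketch only: assumes AttentionBlock is importable from diffusers.models.attention.
from torch import nn

from diffusers.models.attention import AttentionBlock


def swap_deprecated_attention_blocks(module: nn.Module) -> nn.Module:
    for name, child in module.named_children():
        if isinstance(child, AttentionBlock):
            # as_cross_attention() reuses the block's parameters, so this
            # replacement is weight-preserving.
            setattr(module, name, child.as_cross_attention())
        else:
            swap_deprecated_attention_blocks(child)
    return module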