@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Callable, Optional, Union
 
 import torch
@@ -72,7 +73,8 @@ def __init__(
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
 
-        self.scale = dim_head**-0.5 if scale_qk else 1.0
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
         self.heads = heads
         # for slice_size > 0 the attention score computation
@@ -140,7 +142,7 @@ def __init__(
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
             processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
         self.set_processor(processor)
 
@@ -176,6 +178,11 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
+            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                warnings.warn(
+                    "You have specified using flash attention with xFormers, but you have PyTorch 2.0 already installed. "
+                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -229,7 +236,15 @@ def set_use_memory_efficient_attention_xformers(
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
-                processor = AttnProcessor()
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )
 
         self.set_processor(processor)
 
@@ -244,7 +259,13 @@ def set_attention_slice(self, slice_size):
         elif self.added_kv_proj_dim is not None:
             processor = AttnAddedKVProcessor()
         else:
-            processor = AttnProcessor()
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
 
         self.set_processor(processor)
 
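The same processor fallback appears three times in this change: prefer AttnProcessor2_0, which dispatches to torch.nn.functional.scaled_dot_product_attention, whenever PyTorch 2.x is installed and the layer keeps the default query-key scaling (scale_qk=True), and fall back to the plain AttnProcessor otherwise. The following is a minimal standalone sketch of that selection logic, not code from this commit; the helper name default_attn_processor is hypothetical, and it assumes the processor classes are importable from diffusers.models.attention_processor as in this file.

import torch.nn.functional as F

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0


def default_attn_processor(scale_qk: bool = True):
    # F.scaled_dot_product_attention only exists on torch >= 2.0, and the
    # AttnProcessor2_0 path is only taken when the layer keeps the default
    # dim_head**-0.5 scaling (scale_qk=True), mirroring the scale_qk check
    # and the TODO noted in the diff above.
    if hasattr(F, "scaled_dot_product_attention") and scale_qk:
        return AttnProcessor2_0()
    return AttnProcessor()

With such a helper, __init__, set_use_memory_efficient_attention_xformers, and set_attention_slice could share one expression instead of repeating the ternary; the commit instead keeps the expression inline at each call site, alongside the existing comment block.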