
Commit 94a0c64

add: a warning message when using xformers in a PT 2.0 env. (#3365)
* add: a warning message when using xformers in a PT 2.0 env.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 26832aa commit 94a0c64
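
For context, a minimal sketch of how the new warning surfaces from user code; the pipeline class and checkpoint id are illustrative assumptions, not part of this commit. Any model whose attention modules route through Attention.set_use_memory_efficient_attention_xformers should behave the same way.

import torch
from diffusers import StableDiffusionPipeline

# Hypothetical repro: requesting xFormers attention in a PyTorch 2.0 environment.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# On torch >= 2.0 (where F.scaled_dot_product_attention exists) and with the
# default scale_qk=True, this call now emits the UserWarning added in the diff below.
pipe.enable_xformers_memory_efficient_attention()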

File tree

1 file changed: +25, -4 lines

Diff for: src/diffusers/models/attention_processor.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Callable, Optional, Union
 
 import torch
@@ -72,7 +73,8 @@ def __init__(
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
 
-        self.scale = dim_head**-0.5 if scale_qk else 1.0
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
         self.heads = heads
         # for slice_size > 0 the attention score computation
@@ -140,7 +142,7 @@ def __init__(
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
             processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
         self.set_processor(processor)
 
@@ -176,6 +178,11 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
+            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                warnings.warn(
+                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
+                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -229,7 +236,15 @@ set_use_memory_efficient_attention_xformers(
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
-                processor = AttnProcessor()
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )
 
         self.set_processor(processor)
 
@@ -244,7 +259,13 @@ def set_attention_slice(self, slice_size):
         elif self.added_kv_proj_dim is not None:
             processor = AttnAddedKVProcessor()
         else:
-            processor = AttnProcessor()
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
 
         self.set_processor(processor)
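As a standalone illustration of the capability check this commit leans on, here is a small sketch outside the diffusers codebase of the hasattr probe and the resulting processor choice; pick_attention_backend is a hypothetical helper, not an API from this file.

import warnings

import torch.nn.functional as F


def pick_attention_backend(scale_qk: bool = True, want_xformers: bool = False) -> str:
    # F.scaled_dot_product_attention exists on PyTorch >= 2.0.
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    if want_xformers and has_sdpa and scale_qk:
        # Mirrors the new elif branch above: warn that PyTorch 2.0 is present.
        warnings.warn(
            "xFormers was requested, but PyTorch 2.0's native scaled_dot_product_attention "
            "is available; defaulting to the native implementation."
        )
    # Mirrors the processor selection: AttnProcessor2_0 only with the default scale.
    return "AttnProcessor2_0" if has_sdpa and scale_qk else "AttnProcessor"


print(pick_attention_backend(want_xformers=True))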