
add: a warning message when using xformers in a PT 2.0 env. #3365


Merged
3 commits merged on May 10, 2023
29 changes: 25 additions & 4 deletions src/diffusers/models/attention_processor.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Callable, Optional, Union

 import torch
@@ -72,7 +73,8 @@ def __init__(
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax

-        self.scale = dim_head**-0.5 if scale_qk else 1.0
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

Comment on lines -75 to +77

Member:
Why this change? (Just curious)

Member Author:
Otherwise we cannot use the scale_qk parameter in the later methods. This parameter is for now needed to set the default attention processor to AttnProcessor2_0.

I followed how we are already setting the processor in the __init__ for that.
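In other words, the constructor argument is persisted as an instance attribute so that the later methods (set_use_memory_efficient_attention_xformers and set_attention_slice) can consult it when choosing the default processor. A minimal sketch of that pattern, stripped of everything else in the real attention class (the class and method names below are illustrative, not diffusers API):

```python
import torch.nn.functional as F


class AttentionSketch:
    """Illustrative stand-in for the pattern used in the real Attention module."""

    def __init__(self, dim_head: int = 64, scale_qk: bool = True):
        # Persist the constructor argument so later methods can branch on it.
        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

    def pick_default_processor_name(self) -> str:
        # Later methods read self.scale_qk (the constructor argument is out of
        # scope here) when deciding between the torch 2.0 SDPA path and the
        # plain attention processor.
        if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
            return "AttnProcessor2_0"
        return "AttnProcessor"


print(AttentionSketch().pick_default_processor_name())
```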


         self.heads = heads
         # for slice_size > 0 the attention score computation
@@ -140,7 +142,7 @@ def __init__(
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
             processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
         self.set_processor(processor)

@@ -176,6 +178,11 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
+            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                warnings.warn(
+                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
+                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -229,7 +236,15 @@ def set_use_memory_efficient_attention_xformers(
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
-                processor = AttnProcessor()
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )

             self.set_processor(processor)

@@ -244,7 +259,13 @@ def set_attention_slice(self, slice_size):
         elif self.added_kv_proj_dim is not None:
             processor = AttnAddedKVProcessor()
         else:
-            processor = AttnProcessor()
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )

         self.set_processor(processor)

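For context, a rough sketch of how the warning added in the set_use_memory_efficient_attention_xformers hunk above would surface from user code. This assumes a CUDA machine with PyTorch 2.0 and xformers installed; the checkpoint name is only illustrative and not part of this PR:

```python
import warnings

import torch
from diffusers import DiffusionPipeline

# Any diffusers pipeline behaves the same way; this checkpoint is just an example.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # On torch >= 2.0 with the default scale_qk=True, each attention module now
    # emits the warning added above, pointing users to PyTorch's native
    # scaled_dot_product_attention path.
    pipe.enable_xformers_memory_efficient_attention()

print(set(str(w.message) for w in caught))
```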