adapt attention.py to torch 2.0 #2483

Closed · wants to merge 1 commit

Conversation

@caiqi commented on Feb 24, 2023

The AttentionBlock is not adapted to torch 2.0. When using StableDiffusionLatentUpscalePipeline with 768x768 images, it raises an OOM error on a 16GB GPU. This PR uses F.scaled_dot_product_attention to decrease the memory usage. I tested on Colab that this PR fixes the issue: https://colab.research.google.com/drive/1qMwzjweWSUHsYeG932OCECAeA-qkyUjb?usp=sharing
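For readers unfamiliar with the torch 2.0 API, here is a minimal sketch of the kind of change this PR proposes (function names and tensor shapes are illustrative, not the actual diffusers code): F.scaled_dot_product_attention can dispatch to a fused kernel and avoid materializing the full attention matrix that the manual softmax path allocates.

```python
# Minimal sketch, not the actual diffusers patch. Contrasts the manual
# softmax(QK^T)V computation in AttentionBlock with torch 2.0's fused op.
import torch
import torch.nn.functional as F

def attention_manual(query, key, value):
    # Pre-2.0 style: materializes the full (seq_len x seq_len) attention
    # matrix, which dominates memory at large spatial resolutions.
    scale = query.shape[-1] ** -0.5
    attn = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn @ value

def attention_fused(query, key, value):
    # torch 2.0: dispatches to a memory-efficient / flash kernel when
    # available, so the seq_len x seq_len matrix is never materialized.
    return F.scaled_dot_product_attention(query, key, value)

if __name__ == "__main__":
    # Shapes are (batch, heads, seq_len, head_dim); seq_len stands for the
    # number of spatial positions (H*W), so the manual path's seq_len^2
    # buffer is what blows up for high-resolution inputs.
    q, k, v = (torch.randn(1, 1, 4096, 64) for _ in range(3))
    print(torch.allclose(attention_manual(q, k, v),
                         attention_fused(q, k, v), atol=1e-4))
```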

@HuggingFaceDocBuilderDev commented on Feb 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) commented

Hey @caiqi,

Thanks for the PR - I believe that we already merged PyTorch 2.0's fast attention support in this PR: #2303

Think we can close this one, no? Very sorry for not replying earlier.

@caiqi (Author) commented on Mar 8, 2023

> Hey @caiqi,
>
> Thanks for the PR - I believe that we already merged PyTorch 2.0's fast attention support in this PR: #2303
>
> Think we can close this one, no? Very sorry for not replying earlier.

@patrickvonplaten Thanks! It seems PR #2303 updates cross_attention.py but not attention.py?
[screenshot of attention.py]

I hit the memory issue in the image decoder part, which relies on AttentionBlock.
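For context, a reproduction along the lines of the linked Colab might look like the sketch below. This is an assumption on my part, not code taken from this thread: the checkpoint name, call arguments, and whether the pipeline accepts a PIL image directly are all hypothetical, and the Colab above remains the authoritative repro. The point is that decoding the upscaled result runs the VAE decoder, whose AttentionBlock builds a full attention matrix over all spatial positions and runs out of memory on a 16GB GPU without the fused attention path.

```python
# Hypothetical reproduction sketch (not the exact Colab code). Assumes the
# stabilityai/sd-x2-latent-upscaler checkpoint and that `image` accepts a
# PIL image; adjust to match the actual notebook if needed.
import torch
from PIL import Image
from diffusers import StableDiffusionLatentUpscalePipeline

pipe = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

# Placeholder 768x768 input; only the resolution matters for the OOM.
low_res = Image.new("RGB", (768, 768))

upscaled = pipe(
    prompt="a photo of a cat",
    image=low_res,
    num_inference_steps=20,
    guidance_scale=0,
).images[0]
```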

@patrickvonplaten (Contributor) commented

Hey @caiqi,

Yeah, the naming is not great here; attention.py makes use of cross_attention.py in all its attention computations.

@caiqi (Author) commented on Mar 9, 2023

@patrickvonplaten I have tested the latest diffusers code, and it seems that attention.py uses its own attention code. The following is the stack trace:
[screenshot of the stack trace]

I tested it in this Colab notebook: https://colab.research.google.com/drive/1qMwzjweWSUHsYeG932OCECAeA-qkyUjb?usp=sharing

@patrickvonplaten (Contributor) commented

cc @williamberman, we should clean up this attention logic to avoid confusion.

@williamberman (Contributor) commented on Mar 16, 2023

Appreciate it @caiqi :) We're in the process of deprecating AttentionBlock; see #1880.
