Skip to content

T5Attention support for cross-attention #2654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 15, 2023

Conversation

kashif
Copy link
Contributor

@kashif kashif commented Mar 13, 2023

Added support for implementing T5Attention to processors. Needed for #1044

Tested on pytorch 2.0 RC

Fix use of AttnProcessor2_0 for cross attention with mask
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Copy link
Contributor

Can you give a bit more background on what issue is fixed here? I'm not so sure about this tbh.
This Torch 2.0 processor should exactly correspond to:

batch_size, sequence_length, _ = hidden_states.shape
IMO

@kashif
Copy link
Contributor Author

kashif commented Mar 13, 2023

Of course!! my bad!

The issue is that the shape of the mask returned by prepare_attention_mask depends on the sequence_length which is currently wrong in the cross-attention case when the length of the encoder_hidden_states differ from the length of the hidden_states

With this change at least the function doesn't complain... however the outputs vs. CrossAttnProcessor still seem to differ so i am trying to debug that and potentially add tests

@Birch-san
Copy link
Contributor

@kashif worth knowing that I'm also working on support for masks, particularly supporting them for cross-attention:
#2634

probably mine will need a rewrite.

I think the idea in your PR is correct, that the length of the sequence needs to be based on the key. I did the same thing in my PR.

@kashif
Copy link
Contributor Author

kashif commented Mar 14, 2023

thanks @Birch-san i am happy to close this in view of your PR. I also need to add two extra flags for scale and bias_out which I need for the T5Attention implementation. Should I just contribute to your PR?

@kashif kashif changed the title Fix AttnProcessor2_0 T5Attention support for cross-attention Mar 14, 2023
@patrickvonplaten
Copy link
Contributor

Cool, this works thanks a lot for making the changes @kashif !

@patrickvonplaten
Copy link
Contributor

@Birch-san - think we could adapt your PR after this quite easily no?

if processor is None:
processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
if torch.torch_version.TorchVersion(torch.__version__) >= (2, 1, 0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's revert this, we don't need 2.1, 2.0 is enough and I think the logic before was good

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right but then the scaled_dot_product_attention in 2.0 has no scale which is what i would need... but yes i can deal with that in the pipeline?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, ok I think it's fine if Torch 2.0 doesn't work yet for the spectrogram model. Let's maybe just advertise it with the previous PyTorch version and see if the community tries it out on Pytorch 2.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok cool! reverting... i can deal with it or i can also check if attn.scale == 1 and not do this... which is only for spectrogram for now?

@@ -497,7 +511,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No

# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this backwards compatible?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes since if scale=None, the default scale is used ie. the 1/sqrt(D) but only works in 2.1 nightly

@patrickvonplaten
Copy link
Contributor

@kashif can you also run all the slow tests for:

  • Stable Diffusion
  • Stable Diffusion 2
  • UnCLIP

So that we can be sure that nothing is broken

@kashif
Copy link
Contributor Author

kashif commented Mar 14, 2023

ok sure reverting and running slow tests... give me a few!

@kashif
Copy link
Contributor Author

kashif commented Mar 14, 2023

ran slow tests... all failures are of this example:

FAILED tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py::StableDiffusionLatentUpscalePipelineFastTests::test_attention_slicing_forward_pass - AssertionError: 0.003142655 not less than 0.001 : Attention slicing should not affect the inference results
FAILED tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py::StableDiffusionLatentUpscalePipelineFastTests::test_inference_batch_single_identical - assert 0.004434526 < 0.0001
FAILED tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py::StableDiffusionLatentUpscalePipelineFastTests::test_xformers_attention_forwardGenerator_pass - AssertionError: 0.0036953688 not less than 0.0001 : XFormers attention should not affect the inference results
FAILED tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py::StableDiffusionLatentUpscalePipelineIntegrationTests::test_latent_upscaler_fp16 - AssertionError: assert 0.5175781 < 0.5

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool! Thanks for the PR @kashif :-)

@patrickvonplaten patrickvonplaten merged commit cf4227c into huggingface:main Mar 15, 2023
@kashif kashif deleted the fix-AttnProcessor2_0 branch March 15, 2023 17:06
@kashif
Copy link
Contributor Author

kashif commented Mar 15, 2023

ok, thanks! will add fast tests to spectrogram diffusion!

w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* fix AttnProcessor2_0

Fix use of AttnProcessor2_0 for cross attention with mask

* added scale_qk and out_bias flags

* fixed for xformers

* check if it has scale argument

* Update cross_attention.py

* check torch version

* fix sliced attn

* style

* set scale

* fix test

* fixed addedKV processor

* revert back AttnProcessor2_0

* if missing if

* fix inner_dim

---------

Co-authored-by: Patrick von Platen <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* fix AttnProcessor2_0

Fix use of AttnProcessor2_0 for cross attention with mask

* added scale_qk and out_bias flags

* fixed for xformers

* check if it has scale argument

* Update cross_attention.py

* check torch version

* fix sliced attn

* style

* set scale

* fix test

* fixed addedKV processor

* revert back AttnProcessor2_0

* if missing if

* fix inner_dim

---------

Co-authored-by: Patrick von Platen <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* fix AttnProcessor2_0

Fix use of AttnProcessor2_0 for cross attention with mask

* added scale_qk and out_bias flags

* fixed for xformers

* check if it has scale argument

* Update cross_attention.py

* check torch version

* fix sliced attn

* style

* set scale

* fix test

* fixed addedKV processor

* revert back AttnProcessor2_0

* if missing if

* fix inner_dim

---------

Co-authored-by: Patrick von Platen <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants