T5Attention support for cross-attention #2654
Conversation
Fix use of AttnProcessor2_0 for cross attention with mask
The documentation is not available anymore as the PR was closed or merged.
Can you give a bit more background on what issue is fixed here? I'm not so sure about this tbh.
Of course, my bad! The issue is that the shape of the mask returned by […] With this change at least the function doesn't complain... however the outputs vs. […]
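Roughly, the mismatch being described is that masks prepared for the batched-matmul processors come out flattened over heads, while `F.scaled_dot_product_attention` wants a mask that broadcasts against `(batch, heads, q_len, kv_len)`. A minimal sketch with made-up shapes and names, not the actual diffusers code:

```python
import torch
import torch.nn.functional as F

# Made-up shapes, purely to illustrate the broadcast requirement.
batch, heads, q_len, kv_len, head_dim = 2, 8, 77, 77, 64
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Masks prepared for the baddbmm-based processors are typically flattened to
# (batch * heads, 1, kv_len); sdp needs something broadcastable to
# (batch, heads, q_len, kv_len). A zero float mask means "no masking".
flat_mask = torch.zeros(batch * heads, 1, kv_len)
sdp_mask = flat_mask.view(batch, heads, 1, kv_len)  # broadcasts over q_len

out = F.scaled_dot_product_attention(query, key, value, attn_mask=sdp_mask)
print(out.shape)  # torch.Size([2, 8, 77, 64])
```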
Thanks @Birch-san, I am happy to close this in view of your PR. I also need to add two extra flags for […]
Cool, this works, thanks a lot for making the changes @kashif!
@Birch-san - think we could adapt your PR after this quite easily, no?
```python
if processor is None:
    processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
if torch.torch_version.TorchVersion(torch.__version__) >= (2, 1, 0):
```
Let's revert this, we don't need 2.1, 2.0 is enough and I think the logic before was good
Right, but then `scaled_dot_product_attention` in 2.0 has no `scale` argument, which is what I would need... but yes, I can deal with that in the pipeline?
Ah I see, ok I think it's fine if Torch 2.0 doesn't work yet for the spectrogram model. Let's maybe just advertise it with the previous PyTorch version and see if the community tries it out on PyTorch 2.0.
Ok cool, reverting... I can deal with it, or I can also check if `attn.scale == 1` and not do this... which is only for spectrogram for now?
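One way to go about the "check if it has scale argument" idea mentioned in the commit log below, sketched here with hypothetical helper names (this is not the code that actually landed): probe once for the kwarg and, on builds without it, fold a non-default scale into the query.

```python
import math
import torch
import torch.nn.functional as F

def _sdp_supports_scale() -> bool:
    # Probe once whether this torch build accepts the `scale` kwarg
    # (PyTorch 2.0 does not; it appeared in later builds).
    if not hasattr(F, "scaled_dot_product_attention"):
        return False
    probe = torch.zeros(1, 1, 1, 8)
    try:
        F.scaled_dot_product_attention(probe, probe, probe, scale=1.0)
        return True
    except TypeError:
        return False

SDP_HAS_SCALE = _sdp_supports_scale()

def sdp_with_scale(query, key, value, attn_mask=None, scale=None):
    if scale is None or SDP_HAS_SCALE:
        kwargs = {"scale": scale} if SDP_HAS_SCALE else {}
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, **kwargs)
    # Older builds always apply 1/sqrt(head_dim), so fold the requested scale
    # into the query to cancel the built-in factor.
    query = query * scale * math.sqrt(query.shape[-1])
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
```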
```diff
@@ -497,7 +511,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
     # the output of sdp = (batch, num_heads, seq_len, head_dim)
     hidden_states = F.scaled_dot_product_attention(
-        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
```
Is this backwards compatible?
Yes, since if `scale=None` the default scale is used, i.e. 1/sqrt(D), but it only works on the 2.1 nightly.
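As a quick sanity check of that claim (on a torch build that already has the kwarg), omitting `scale`, passing `scale=None`, and passing `1/sqrt(head_dim)` explicitly should all agree:

```python
import math
import torch
import torch.nn.functional as F

# Requires a PyTorch build whose scaled_dot_product_attention accepts `scale`
# (the 2.1 nightly at the time of this thread). Shapes are arbitrary.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

default = F.scaled_dot_product_attention(q, k, v)
explicit_none = F.scaled_dot_product_attention(q, k, v, scale=None)
explicit_val = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(q.shape[-1]))

print(torch.allclose(default, explicit_none, atol=1e-6))
print(torch.allclose(default, explicit_val, atol=1e-6))
```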
@kashif can you also run all the slow tests for: […]
So that we can be sure that nothing is broken.
Ok sure, reverting and running slow tests... give me a few!
Ran slow tests... all failures are of this example: […]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool! Thanks for the PR @kashif :-)
Ok, thanks! Will add fast tests to spectrogram diffusion!
* fix AttnProcessor2_0: Fix use of AttnProcessor2_0 for cross attention with mask
* added scale_qk and out_bias flags
* fixed for xformers
* check if it has scale argument
* Update cross_attention.py
* check torch version
* fix sliced attn
* style
* set scale
* fix test
* fixed addedKV processor
* revert back AttnProcessor2_0
* if missing if
* fix inner_dim

---------

Co-authored-by: Patrick von Platen <[email protected]>
Added support for implementing T5Attention via the attention processors. Needed for #1044.
Tested on PyTorch 2.0 RC.
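For context on why the flags matter: T5-style attention skips the 1/sqrt(d) scaling of Q·K and uses no bias on its output projection, which is what the `scale_qk` and `out_bias` flags mentioned above are meant to expose. A minimal, self-contained sketch of that idea (a hypothetical module, not the diffusers `CrossAttention` API):

```python
import torch
import torch.nn as nn

class AttentionSketch(nn.Module):
    """Illustrative only: how scale_qk / out_bias map onto T5-style attention."""

    def __init__(self, query_dim, heads=8, dim_head=64, scale_qk=True, out_bias=True):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        # scale_qk=False gives scale 1.0, i.e. T5's unscaled dot-product attention.
        self.scale = dim_head**-0.5 if scale_qk else 1.0
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
        # out_bias=False drops the bias on the output projection, as T5 does.
        self.to_out = nn.Linear(inner_dim, query_dim, bias=out_bias)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # For simplicity this sketch assumes the encoder states share query_dim.
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        b, n, _ = hidden_states.shape
        q = self.to_q(hidden_states).view(b, n, self.heads, -1).transpose(1, 2)
        k = self.to_k(context).view(b, context.shape[1], self.heads, -1).transpose(1, 2)
        v = self.to_v(context).view(b, context.shape[1], self.heads, -1).transpose(1, 2)
        # Written out manually so the role of self.scale is explicit; in the PR the
        # same scale is routed into F.scaled_dot_product_attention via its scale kwarg.
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if attention_mask is not None:
            scores = scores + attention_mask
        out = torch.matmul(scores.softmax(dim=-1), v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

attn = AttentionSketch(query_dim=512, scale_qk=False, out_bias=False)  # T5-style settings
out = attn(torch.randn(2, 16, 512), encoder_hidden_states=torch.randn(2, 24, 512))
print(out.shape)  # torch.Size([2, 16, 512])
```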