Flax memory efficient attention #2889
Conversation
The documentation is not available anymore as the PR was closed or merged.
@@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_pt = kwargs.pop("from_pt", False)
    use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
I decided to enable this at pipeline load time. It's not straightforward to replicate the recursive logic we created for xformers, because Flax submodules are not available unless you are in `apply`, so I didn't find a way to make functions to enable/disable the setting on demand. Instead, the configuration is set up when the pipeline is instantiated.

Another alternative would have been to pass `use_memory_efficient_attention` as an additional argument to `generate`, so it's applied on a per-inference basis. I thought making it constant for the pipeline made sense for the Flax case.
Any other thoughts about this @patrickvonplaten, @yiyixuxu, @williamberman?
I don't have a strong opinion about this - I think passing it as an argument is probably easier to implement, but what you already did is good :)
As mentioned in #2231, inference is slower when memory efficient attention is enabled, but it allows additional use-cases (larger batch sizes, resolution bucketing). See #2231 (review), #2231 (comment).
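For illustration, enabling the flag at load time might look like the sketch below; the checkpoint, revision, and dtype are just examples and not part of this PR:

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Illustrative checkpoint/revision; any Flax-compatible weights should work.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
    # New kwarg from this PR; fixed for the lifetime of the pipeline.
    use_memory_efficient_attention=True,
)
```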
@@ -0,0 +1,101 @@
import functools
Do we really need a new file for this? I think we should just put it in `attention_flax`. Currently we don't really have any "utility" files.
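For context, here is a minimal sketch of the key-chunking idea behind memory-efficient attention in JAX (Rabe & Staats, "Self-attention Does Not Need O(n²) Memory"). It is an illustration rather than the exact code in this PR; the function name and the single-head, unbatched shapes are simplifying assumptions:

```python
import jax
import jax.numpy as jnp


def chunked_attention(query, key, value, key_chunk_size=4096):
    """Attention with keys/values processed chunk by chunk.

    Shapes: query (q, d), key (k, d), value (k, v_d). For simplicity this
    sketch requires k to be divisible by the effective chunk size.
    """
    num_q, d = query.shape
    num_kv, v_d = value.shape
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(d)  # fold the 1/sqrt(d) scaling into the query

    def summarize_chunk(q, k, v):
        # Unnormalized softmax over one key chunk, stabilized by the
        # chunk-local maximum score.
        scores = jnp.einsum("qd,kd->qk", q, k)
        chunk_max = jnp.max(scores, axis=-1, keepdims=True)
        exp_scores = jnp.exp(scores - chunk_max)
        chunk_values = jnp.einsum("qk,kv->qv", exp_scores, v)
        return chunk_values, exp_scores.sum(axis=-1), chunk_max.squeeze(-1)

    def scan_chunk(carry, chunk):
        acc_values, acc_weights, acc_max = carry
        k_chunk, v_chunk = chunk
        chunk_values, chunk_weights, chunk_max = summarize_chunk(query, k_chunk, v_chunk)
        # Rescale both running sums onto a shared maximum before combining,
        # so the softmax normalization stays consistent across chunks.
        new_max = jnp.maximum(acc_max, chunk_max)
        acc_scale, chunk_scale = jnp.exp(acc_max - new_max), jnp.exp(chunk_max - new_max)
        acc_values = acc_values * acc_scale[:, None] + chunk_values * chunk_scale[:, None]
        acc_weights = acc_weights * acc_scale + chunk_weights * chunk_scale
        return (acc_values, acc_weights, new_max), None

    keys = key.reshape(num_kv // key_chunk_size, key_chunk_size, d)
    values = value.reshape(num_kv // key_chunk_size, key_chunk_size, v_d)
    init = (
        jnp.zeros((num_q, v_d)),   # running weighted values
        jnp.zeros((num_q,)),       # running softmax denominator
        jnp.full((num_q,), -jnp.inf),  # running max score per query
    )
    (out, weights, _), _ = jax.lax.scan(scan_chunk, init, (keys, values))
    return out / weights[:, None]
```

The point is that only one `(num_q, key_chunk_size)` block of scores exists at a time, with the running softmax renormalized as each chunk is folded in.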
@@ -108,19 +138,26 @@ class FlaxBasicTransformerBlock(nn.Module):
        Whether to only apply cross attention.
    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
        Parameters `dtype`
    use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
nice!
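Since Flax module configuration is frozen at construction (the point about `apply` above), the flag naturally lives as a module attribute. A hedged sketch; every field except the new flag is an assumption about the surrounding module:

```python
import flax.linen as nn
import jax.numpy as jnp


class FlaxBasicTransformerBlock(nn.Module):
    # Attributes are set once at construction; there is no post-hoc
    # enable/disable toggle analogous to the PyTorch xformers helpers.
    dim: int
    n_heads: int
    d_head: int
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
```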
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
if self.use_memory_efficient_attention:
nice!
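To make the branch concrete, here is a hedged sketch of how the two paths might dispatch inside the attention module's `__call__`; `chunked_attention` refers to the sketch earlier in this thread and stands in for the helper added by the PR:

```python
if self.use_memory_efficient_attention:
    # Chunked path: never materializes the full (query, key) score matrix.
    # chunked_attention already folds in the 1/sqrt(d) scaling.
    hidden_states = jax.vmap(chunked_attention)(query_states, key_states, value_states)
else:
    # Standard path, as in the diff above: full score matrix in memory.
    attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
    attention_scores = attention_scores * self.scale
    attention_probs = nn.softmax(attention_scores, axis=2)
    hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
```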
Looks good to me, though I think we should move the memory efficient attention function directly to the attention file.
Also, it would be nice to have some tests and docs for this (but not urgent, really, IMO).
@patrickvonplaten any ideas what kind of test I should add?
A test that this new attention more or less matches the old attention in output :-)
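Along those lines, a hedged sketch of such an equivalence test, using the `chunked_attention` sketch from earlier in the thread as a stand-in for the new path (the PR's actual test ended up being a slow pipeline test):

```python
import jax
import jax.numpy as jnp
import numpy as np


def test_chunked_matches_standard_attention():
    q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
    query = jax.random.normal(q_rng, (64, 32))
    key = jax.random.normal(k_rng, (128, 32))
    value = jax.random.normal(v_rng, (128, 32))

    # Reference: standard attention with the full score matrix.
    scores = jnp.einsum("id,jd->ij", query, key) / jnp.sqrt(32.0)
    expected = jax.nn.softmax(scores, axis=-1) @ value

    # Chunked path; 16 divides the key length of 128.
    actual = chunked_attention(query, key, value, key_chunk_size=16)

    np.testing.assert_allclose(np.asarray(actual), np.asarray(expected), atol=1e-5)
```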
Seems correct up to 7 decimal places, @patrickvonplaten.
I wrote a slow pipeline test. Let me know if that's OK and we can merge @patrickvonplaten @yiyixuxu
@pcuenca happy to merge here!
* add use_memory_efficient params placeholder
* test
* add memory efficient attention jax
* add memory efficient attention jax
* newline
* forgot dot
* Rename use_memory_efficient
* Keep dtype last.
* Actually use key_chunk_size
* Rename symbol
* Apply style
* Rename use_memory_efficient
* Keep dtype last
* Pass `use_memory_efficient_attention` in `from_pretrained`
* Move JAX memory efficient attention to attention_flax.
* Simple test.
* style

---------

Co-authored-by: muhammad_hanif <[email protected]>
Co-authored-by: MuhHanif <[email protected]>
Continues work from #2231. @MuhHanif, I had to open this new PR because yours is in `main`. You're still the main author, of course. I just applied the comments we discussed in #2231 and a couple other changes.