[Flax] added memory efficient attention for U-net #2231
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Uhh, a little help with the code quality checks?
That's very cool! @pcuenca do you maybe have some time to look into it? :-)
@pcuenca do you think you maybe have some time to review this PR? :-)
Sorry for the delay. I have been testing this PR (for inference) on a TPU v3-8 and found that speed is a bit slower but batch sizes can be larger. The use of memory-efficient attention is not as beneficial as in the PyTorch world, because `jit` and `pmap` already do a great job by default. In fact, I believe that batch sizes could arguably be made larger without memory-efficient attention by partitioning the model and data differently. Still, this could be interesting for some use cases.
These are my results:
| Total Batch Size | Compile (main) | Inference (main) | Compile (this PR) | Inference (this PR) |
|---|---|---|---|---|
| 8 | 1m 25s | 2.47 | 1m 39s | 3.61 |
| 32 | 1m 55s | 7.09 | 2m 07s | 10.7 |
| 64 | 2m 13s | 13.8 | 2m 21s | 20.2 |
| 160 | 3m 31s | 34.4 | 3m 34s | 34.5 |
| 192 | OOM | OOM | 4m 13s | 59.2 |
| 256 | OOM | OOM | 4m 53s | 78.0 |
@MuhHanif Are these results consistent with what you have observed?
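For context, the numbers above come from the standard `pmap`/`jit`-style Flax inference flow; below is a minimal sketch assuming the public `FlaxStableDiffusionPipeline` API on a TPU v3-8. The checkpoint, revision, and prompts are illustrative, and this is not the exact benchmark script.

```python
# Sketch of pmap-ed Flax Stable Diffusion inference on a TPU v3-8 (8 devices).
# Checkpoint/revision are illustrative, not the exact benchmark configuration.
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
)

num_devices = jax.device_count()  # 8 on a TPU v3-8
prompts = ["a photo of an astronaut riding a horse"] * num_devices

prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)   # split the batch across devices
p_params = replicate(params)     # copy the parameters to every device
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

# With jit=True the sampling loop is compiled once per input shape;
# subsequent calls with the same shapes reuse the compiled program.
images = pipeline(prompt_ids, p_params, rng, jit=True).images
```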
Thanks a lot! As stated before, this could be useful for some use cases. The code works and it looks fine to me, I just have a few suggestions:
- Rename some symbols for consistency with previous naming conventions in the codebase.
- Would it be possible to enable memory-efficient attention after the pipeline has been loaded? We could use something like `unet.enable_jax_memory_efficient_attention()` (a rough usage sketch follows after this comment).
- Would it be possible to add a couple of simple tests?
I'm curious about the way you are planning to use this feature. Have you measured any training figures by any chance? If so, it could be interesting to add some comments about this feature in the training docs (I can help with that).
Please, let us know if these comments make sense to you. Once again, sorry for being late to review.
(Also, please note that we recently updated our quality/style tooling, let me know if you need help making style checks pass).
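To illustrate the second suggestion above, here is a hypothetical usage sketch. `enable_jax_memory_efficient_attention()` does not exist in diffusers at the time of this review, and since Flax modules are frozen dataclasses such a toggle would likely have to return a reconfigured module rather than mutate state in place. The checkpoint ID is illustrative.

```python
# Hypothetical sketch of the suggested post-load toggle; the method name
# mirrors PyTorch's `enable_xformers_memory_efficient_attention()` but is
# NOT an existing diffusers API at the time of this review.
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    subfolder="unet",
    revision="bf16",
    dtype=jnp.bfloat16,
)

# Flax modules are frozen dataclasses, so a toggle like this would likely
# return a new module instance configured for the memory-efficient path.
unet = unet.enable_jax_memory_efficient_attention()
```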
@@ -31,13 +32,16 @@ class FlaxAttentionBlock(nn.Module):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient (`bool`, *optional*, defaults to `False`):
Suggested change:
- use_memory_efficient (`bool`, *optional*, defaults to `False`):
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Can we rename it to `use_memory_efficient_attention` everywhere?
""" | ||
query_dim: int | ||
heads: int = 8 | ||
dim_head: int = 64 | ||
dropout: float = 0.0 | ||
dtype: jnp.dtype = jnp.float32 | ||
use_memory_efficient: bool = False |
Suggested change:
- use_memory_efficient: bool = False
+ use_memory_efficient_attention: bool = False
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
Interesting. I'd be curious to know what the impact is for training. Why do we need to disable `prevent_cse`?
> Interesting. I'd be curious to know what the impact is for training. Why do we need to disable `prevent_cse`?

I don't know, I just took the code snippet from "Self-attention Does Not Need O(n²) Memory" and drew inspiration from AminRezaei0x443's implementation. I could test it with `prevent_cse` enabled and see what happens.
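For reference, the diff fragments above come from a chunked attention routine in the style of Rabe & Staats (2021). Below is a condensed sketch of that pattern, not the exact code in this diff: `jax.checkpoint` causes each chunk's softmax statistics to be recomputed in the backward pass instead of being stored, which is what keeps memory roughly linear in sequence length. (Per the `jax.checkpoint` documentation, `prevent_cse=False` is the setting suggested when the checkpointed function runs under scan-like control flow, where CSE protection is unnecessary and only costs performance.)

```python
# Condensed sketch of one query block attending over chunked keys/values,
# following "Self-attention Does Not Need O(n^2) Memory" (Rabe & Staats, 2021).
# Shapes: query (q, heads, d), key/value (k, heads, d). Not the exact PR code.
import functools

import jax
import jax.numpy as jnp


def query_chunk_attention(query, key, value, key_chunk_size=4096):
    num_kv, num_heads, k_features = key.shape
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    # Rematerialize each chunk's statistics in the backward pass instead of
    # keeping the full attention matrix alive.
    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("qhd,khd->qhk", query, key)
        max_score = jax.lax.stop_gradient(jnp.max(attn_weights, axis=-1, keepdims=True))
        exp_weights = jnp.exp(attn_weights - max_score)
        exp_values = jnp.einsum("vhf,qhv->qhf", value, exp_weights)
        return exp_values, exp_weights.sum(axis=-1), max_score[..., 0]

    def chunk_scanner(chunk_idx):
        key_chunk = jax.lax.dynamic_slice(
            key, (chunk_idx, 0, 0), slice_sizes=(key_chunk_size, num_heads, k_features)
        )
        value_chunk = jax.lax.dynamic_slice(
            value, (chunk_idx, 0, 0), slice_sizes=(key_chunk_size, num_heads, v_features)
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(
        chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)
    )

    # Combine per-chunk partial softmax results into the exact softmax output.
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)
    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
    return all_values / all_weights
```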

    return all_values / all_weights


def memory_efficient_attention(
Maybe rename it to `jax_memory_efficient_attention` for symmetry with `xformers_memory_efficient_attention`.
Hi @pcuenca! Sorry for the late reply. Yes, it's roughly 10-20% slower when using memory-efficient attention during training (I didn't profile it thoroughly, though).
Okay.
It should be possible with something like that.
What kind of test should I add?
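One possible answer to the testing question, sketched under assumptions: a numerical equivalence check between the memory-efficient path and naive softmax attention on small random inputs. The import path, function name, chunk-size arguments, and the `(seq, heads, dim)` layout below are guesses based on the review comments rather than the final API of this PR.

```python
# Hypothetical equivalence test; names, import path, and shapes are
# assumptions and should be adjusted to whatever the PR finally exposes.
import jax
import jax.numpy as jnp

# Assumed import path; the function may live elsewhere or use another name.
from diffusers.models.attention_flax import jax_memory_efficient_attention


def naive_attention(query, key, value):
    # Reference O(n^2) attention, layout (seq, heads, dim).
    scale = 1.0 / jnp.sqrt(query.shape[-1])
    weights = jax.nn.softmax(jnp.einsum("qhd,khd->qhk", query, key) * scale, axis=-1)
    return jnp.einsum("qhk,khd->qhd", weights, value)


def test_memory_efficient_attention_matches_naive():
    q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
    query = jax.random.normal(q_key, (64, 8, 40))
    key = jax.random.normal(k_key, (64, 8, 40))
    value = jax.random.normal(v_key, (64, 8, 40))

    expected = naive_attention(query, key, value)
    # Small chunk sizes force several chunks even on tiny inputs.
    actual = jax_memory_efficient_attention(
        query, key, value, query_chunk_size=16, key_chunk_size=16
    )

    assert jnp.allclose(expected, actual, atol=1e-4)
```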
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Gentle ping here @pcuenca
Memory-efficient attention implementation for Stable Diffusion Flax, enabling a higher batch count / resolution during training or fine-tuning.
Enable it by passing the argument `use_memory_efficient=True` to `from_pretrained` when loading the U-Net model.
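A minimal usage sketch based on the description above; the checkpoint ID and revision are illustrative, and the flag may end up renamed to `use_memory_efficient_attention` per the review comments.

```python
# Sketch: load the Flax U-Net with the memory-efficient attention flag
# added by this PR. Checkpoint ID and revision are illustrative.
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    revision="bf16",
    dtype=jnp.bfloat16,
    use_memory_efficient=True,  # flag introduced in this PR
)
```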