[Flax] added memory efficient attention for U-net #2231

Closed

Conversation

@MuhHanif (Contributor) commented Feb 3, 2023

Memory-efficient attention implementation for Stable Diffusion Flax, enabling higher batch counts / resolutions during training or fine-tuning.
Enable it by passing use_memory_efficient=True to from_pretrained() when loading the U-Net model.
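
For example, a minimal usage sketch (the checkpoint id below is just an illustration, and the flag name follows this PR so it may still change during review):

```python
# Illustrative sketch only: the checkpoint id is an example, and
# use_memory_efficient is the flag name proposed in this PR (it may be
# renamed, e.g. to use_memory_efficient_attention, during review).
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # example checkpoint with Flax weights
    subfolder="unet",
    use_memory_efficient=True,
)
```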

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@MuhHanif (Contributor, Author) commented Feb 3, 2023

Could I get a little help with the code quality checks?

@patrickvonplaten (Contributor) commented

That's very cool! @pcuenca do you maybe have some time to look into it? :-)

@patrickvonplaten (Contributor) commented

@pcuenca do you think you maybe have some time to review this PR? :-)

@pcuenca (Member) left a comment

Sorry for the delay. I have been testing this PR (for inference) on a TPU v3-8, and found that speed is a bit slower but batch sizes can be larger. The use of memory-efficient attention is not as beneficial as in the PyTorch world, because jit and pmap already do a great job by default. In fact, I believe that batch sizes could arguably be made larger without memory-efficient attention by partitioning model and data differently. Still, this could be interesting for some use cases.

These are my results:

| Total Batch Size | Compile (Diffusers main) | Inference (Diffusers main) | Compile (this PR) | Inference (this PR) |
|---|---|---|---|---|
| 8 | 1m 25s | 2.47 | 1m 39s | 3.61 |
| 32 | 1m 55s | 7.09 | 2m 07s | 10.7 |
| 64 | 2m 13s | 13.8 | 2m 21s | 20.2 |
| 160 | 3m 31s | 34.4 | 3m 34s | 34.5 |
| 192 | OOM | OOM | 4m 13s | 59.2 |
| 256 | OOM | OOM | 4m 53s | 78.0 |

@MuhHanif Are these results consistent with what you have observed?

@pcuenca (Member) left a comment

Thanks a lot! As stated before, this could be useful for some use cases. The code works and it looks fine to me, I just have a few suggestions:

  • Rename some symbols for consistency with previous naming conventions in the codebase.
  • Would it be possible to enable memory efficient attention after the pipeline has been loaded? We could use something like unet.enable_jax_memory_efficient_attention().
  • Would it be possible to add a couple of simple tests?

I'm curious about the way you are planning to use this feature. Have you measured any training figures by any chance? If so, it could be interesting to add some comments about this feature in the training docs (I can help with that).

Please let us know if these comments make sense to you. Once again, sorry for being late to review.

(Also, please note that we recently updated our quality/style tooling, let me know if you need help making style checks pass).

@@ -31,13 +32,16 @@ class FlaxAttentionBlock(nn.Module):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient (`bool`, *optional*, defaults to `False`):
@pcuenca (Member) commented:

Suggested change
use_memory_efficient (`bool`, *optional*, defaults to `False`):
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):

@pcuenca (Member) commented:

Can we rename it to use_memory_efficient_attention everywhere?

"""
query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
use_memory_efficient: bool = False
@pcuenca (Member) commented:

Suggested change
use_memory_efficient: bool = False
use_memory_efficient_attention: bool = False

key_chunk_size = min(key_chunk_size, num_kv)
query = query / jnp.sqrt(k_features)

@functools.partial(jax.checkpoint, prevent_cse=False)
@pcuenca (Member) commented:

Interesting. I'd be curious to know what's the impact for training. Why do we need to disable prevent_cse?

@MuhHanif (Contributor, Author) commented:

> Interesting. I'd be curious to know what's the impact for training. Why do we need to disable prevent_cse?

I'm not sure; I just took the code snippet from Self-attention Does Not Need O(n²) Memory and took inspiration from AminRezaei0x443. I could test it with prevent_cse enabled and see what happens.
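
For context, here is a condensed sketch of the chunked attention recipe from that paper (simplified, with placeholder names; it skips query chunking and assumes the key length is a multiple of the chunk size, so it is not the exact code in this PR):

```python
# Condensed sketch of chunked ("memory-efficient") attention from
# "Self-attention Does Not Need O(n^2) Memory" (Rabe & Staats).
# Placeholder names; assumes num_kv % key_chunk_size == 0.
import functools

import jax
import jax.numpy as jnp


def chunked_attention(query, key, value, key_chunk_size=4096):
    # query: (q, heads, d), key/value: (kv, heads, d)
    num_kv, num_heads, k_features = key.shape
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    # Rematerialize each chunk summary in the backward pass instead of
    # storing the full attention matrix; prevent_cse=False keeps XLA from
    # merging the recomputation back together under jit.
    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("qhd,khd->qhk", query, key)
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)
        exp_values = jnp.einsum("vhf,qhv->qhf", value, exp_weights)
        return exp_values, exp_weights.sum(axis=-1), max_score[..., 0]

    def chunk_scanner(chunk_idx):
        key_chunk = jax.lax.dynamic_slice(
            key, (chunk_idx, 0, 0), (key_chunk_size, num_heads, k_features)
        )
        value_chunk = jax.lax.dynamic_slice(
            value, (chunk_idx, 0, 0), (key_chunk_size, num_heads, v_features)
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(
        chunk_scanner, jnp.arange(0, num_kv, key_chunk_size)
    )

    # Combine per-chunk partial softmax sums into the exact softmax output.
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)
    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
    return all_values / all_weights
```

Each key/value chunk gets its own partial softmax (tracking a running max for numerical stability), and jax.checkpoint recomputes those chunk summaries in the backward pass instead of keeping the full query x key attention matrix alive.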


return all_values / all_weights

def memory_efficient_attention(
@pcuenca (Member) commented:

Maybe rename to jax_memory_efficient_attention for symmetry with xformers_memory_efficient_attention

@MuhHanif (Contributor, Author) commented Mar 4, 2023

> Sorry for the delay. I have been testing this PR (for inference) on a TPU v3-8, and found that speed is a bit slower but batch sizes can be larger. The use of memory-efficient attention is not as beneficial as in the PyTorch world, because jit and pmap already do a great job by default. In fact, I believe that batch sizes could arguably be made larger without memory-efficient attention by partitioning model and data differently. Still, this could be interesting for some use cases.
>
> These are my results:
>
> | Total Batch Size | Compile (Diffusers main) | Inference (Diffusers main) | Compile (this PR) | Inference (this PR) |
> |---|---|---|---|---|
> | 8 | 1m 25s | 2.47 | 1m 39s | 3.61 |
> | 32 | 1m 55s | 7.09 | 2m 07s | 10.7 |
> | 64 | 2m 13s | 13.8 | 2m 21s | 20.2 |
> | 160 | 3m 31s | 34.4 | 3m 34s | 34.5 |
> | 192 | OOM | OOM | 4m 13s | 59.2 |
> | 256 | OOM | OOM | 4m 53s | 78.0 |
>
> @MuhHanif Are these results consistent with what you have observed?

Hi @pcuenca! Sorry for the late reply.

Yes, it's roughly 10-20% slower when using memory-efficient attention during training (I didn't profile it thoroughly, though). The slowdown is expected, because I chunk the self-attention query matrix to be as small as the innermost (center-most) layer of the U-Net. But it makes training with a different resolution per batch (NAI-style resolution bucketing, e.g. 512x512, 512x704, etc.) possible without OOM.
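
A rough back-of-the-envelope illustration of the attention-score memory this avoids (the latent sizes, head count, and dtype below are assumed typical Stable Diffusion 1.x values, not numbers measured in this PR):

```python
# Rough illustration only; head count, dtype, and latent sizes are assumed
# typical Stable Diffusion 1.x values, not measurements from this PR.
tokens_outer = 64 * 64   # self-attention tokens at the largest U-Net level (512x512 image -> 64x64 latent)
tokens_inner = 8 * 8     # tokens at the innermost (mid) block
heads, bytes_per_elem = 8, 2  # assumed head count and bf16 storage

full_scores = tokens_outer * tokens_outer * heads * bytes_per_elem     # 256 MiB of scores per sample
chunked_scores = tokens_inner * tokens_outer * heads * bytes_per_elem  # 4 MiB per query chunk
print(f"{full_scores / 2**20:.0f} MiB vs {chunked_scores / 2**20:.0f} MiB")
```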

@MuhHanif (Contributor, Author) commented Mar 4, 2023

> * Rename some symbols for consistency with previous naming conventions in the codebase.

Okay.

> * Would it be possible to enable memory efficient attention after the pipeline has been loaded? We could use something like `unet.enable_jax_memory_efficient_attention()`.

It should be possible with `unet.use_memory_efficient = True`.

> * Would it be possible to add a couple of simple tests?

What kind of tests should I add?
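
One possible shape for such a test, sketched against the chunked_attention example above rather than the actual function in this PR (a real test would target the PR's function and follow the repo's test conventions): a numerical-equivalence check against plain softmax attention.

```python
# Sketch of a possible equivalence test: chunked attention should match plain
# softmax attention numerically. Uses the chunked_attention sketch from above;
# a real test would call the function added in this PR instead.
import jax
import jax.numpy as jnp


def reference_attention(query, key, value):
    scores = jnp.einsum("qhd,khd->qhk", query / jnp.sqrt(query.shape[-1]), key)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("qhk,khf->qhf", weights, value)


def test_chunked_matches_reference():
    rng = jax.random.PRNGKey(0)
    kq, kk, kv = jax.random.split(rng, 3)
    query = jax.random.normal(kq, (16, 8, 64))
    key = jax.random.normal(kk, (128, 8, 64))
    value = jax.random.normal(kv, (128, 8, 64))

    out_ref = reference_attention(query, key, value)
    out_chunked = chunked_attention(query, key, value, key_chunk_size=32)

    assert jnp.allclose(out_ref, out_chunked, atol=1e-4)
```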

@github-actions bot commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

The github-actions bot added the stale label (Issues that haven't received updates) on Mar 28, 2023.
@patrickvonplaten (Contributor) commented

Gentle ping here @pcuenca
