
Flax memory efficient attention #2889


Merged: 24 commits into main from 2231-flax-memory-efficient-attention on Apr 12, 2023

Conversation

pcuenca (Member) commented Mar 29, 2023

Continues work from #2231. @MuhHanif, I had to open this new PR because yours is in main. You're still the main author, of course. I just applied the comments we discussed in #2231 and a couple of other changes.


HuggingFaceDocBuilderDev commented Mar 29, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
pcuenca (Member, Author):

I decided to enable this at pipeline load time. It's not straightforward to replicate the recursive logic we created for xformers because Flax submodules are not available unless you are inside `apply`, so I didn't find a way to write functions that enable/disable the setting on demand. Instead, the configuration is set up when the pipeline is instantiated.

Another alternative would have been to pass `use_memory_efficient_attention` as an additional argument to `generate` so it's applied on a per-inference basis. I thought making it constant for the pipeline made sense for the Flax case.

Any other thoughts about this @patrickvonplaten, @yiyixuxu, @williamberman?
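
For illustration, enabling the flag at load time would look roughly like this. This is a sketch, not code from the PR: the checkpoint id and revision are placeholders, and it assumes the pipeline's `from_pretrained` forwards the kwarg to the model loaders as described above.

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Sketch: the flag is consumed at load time, so every subsequent inference call
# on this pipeline uses memory efficient attention.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint
    revision="bf16",                  # placeholder revision with Flax bf16 weights
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)
```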

Collaborator:

I don't have a strong opinion about this - I think passing it as an argument is probably easier to implement, but what you already did is good:)


pcuenca commented Mar 29, 2023

As mentioned in #2231, inference is slower when memory efficient attention is enabled, but it allows additional use-cases (larger batch sizes, resolution bucketing). See #2231 (review), #2231 (comment)

@@ -0,0 +1,101 @@
import functools
Contributor:

Do we really need a new file for this? I think we should just put it in `attention_flax`. Currently we don't really have any "utility" files.
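
For context, the helper being added here implements chunked ("memory efficient") attention, which processes keys and values in chunks so the full attention matrix is never materialized. A minimal, self-contained JAX sketch of the idea follows; it is not the PR's actual implementation, and only the `key_chunk_size` name is borrowed from the commit messages.

```python
import jax
import jax.numpy as jnp


def chunked_attention(query, key, value, key_chunk_size=4096):
    """Chunked softmax attention over query (q_len, d) and key/value (k_len, d)."""
    q_len, d = query.shape
    k_len = key.shape[0]
    scale = d ** -0.5
    num_chunks = -(-k_len // key_chunk_size)  # ceil division

    # Running accumulators for the numerically stable (online) softmax.
    acc = jnp.zeros((q_len, value.shape[-1]))
    row_sum = jnp.zeros((q_len, 1))
    row_max = jnp.full((q_len, 1), -jnp.inf)

    for i in range(num_chunks):
        k_chunk = key[i * key_chunk_size : (i + 1) * key_chunk_size]
        v_chunk = value[i * key_chunk_size : (i + 1) * key_chunk_size]
        scores = (query @ k_chunk.T) * scale  # only (q_len, chunk) at a time
        chunk_max = scores.max(axis=-1, keepdims=True)
        new_max = jnp.maximum(row_max, chunk_max)
        # Rescale what was accumulated so far to the new max, then add this chunk.
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(scores - new_max)
        acc = acc * correction + probs @ v_chunk
        row_sum = row_sum * correction + probs.sum(axis=-1, keepdims=True)
        row_max = new_max

    return acc / row_sum


# Quick sanity check against the naive formulation.
q = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
k = jax.random.normal(jax.random.PRNGKey(1), (1024, 64))
v = jax.random.normal(jax.random.PRNGKey(2), (1024, 64))
naive = jax.nn.softmax((q @ k.T) * 64 ** -0.5, axis=-1) @ v
assert jnp.allclose(chunked_attention(q, k, v, key_chunk_size=256), naive, atol=1e-4)
```

A real implementation would use `lax.scan` and `jax.checkpoint` rather than a Python loop, but the accumulation logic is the same.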

@@ -108,19 +138,26 @@ class FlaxBasicTransformerBlock(nn.Module):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Contributor:

nice!

attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
if self.use_memory_efficient_attention:
Contributor:

nice!

patrickvonplaten (Contributor) left a comment

Looks good to me; I think we should, however, move the memory efficient attention function directly to the attention file.

It would also be nice to have some tests and docs for this (not really urgent IMO).

MuhHanif (Contributor):

@patrickvonplaten any ideas on what kind of test I should add?

patrickvonplaten (Contributor):

A test that this new attention matches the old attention more or less in output result :-)


MuhHanif commented Apr 11, 2023

Seems correct up to 7 decimal places, @patrickvonplaten.

> A test that this new attention matches the old attention more or less in output result :-)

```python
import diffusers
import jax
import numpy as np

# init model
unet_attn_init = diffusers.models.attention_flax.FlaxCrossAttention(8 * 320)

# dummy inputs
key = jax.random.PRNGKey(42)
key, rand = jax.random.split(key)
context = jax.random.normal(rand, [1, 64 * 64, 320])

# generate random params
unet_attn_params = unet_attn_init.init(key, context)

with jax.default_matmul_precision("float32"):
    # model without memory efficient attention
    unet_attn = diffusers.models.attention_flax.FlaxCrossAttention(8 * 320)
    b = unet_attn.apply(unet_attn_params, hidden_states=context)

    # model with memory efficient attention (flag name as in the diff above)
    unet_attn_eff = diffusers.models.attention_flax.FlaxCrossAttention(
        8 * 320, use_memory_efficient_attention=True
    )
    a = unet_attn_eff.apply(unet_attn_params, hidden_states=context)

np.testing.assert_almost_equal(np.array(a), np.array(b), decimal=7)
```



pcuenca commented Apr 11, 2023

I wrote a slow pipeline test; let me know if that's OK and we can merge. @patrickvonplaten @yiyixuxu
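
For reference, a pipeline-level comparison in that spirit could look like the sketch below. Everything here is illustrative: the test name, checkpoint id, step count, and tolerance are assumptions, not the contents of the slow test added in this PR.

```python
import jax
import numpy as np

from diffusers import FlaxStableDiffusionPipeline


def test_flax_memory_efficient_attention_matches_default():
    prompt = "A photograph of an astronaut riding a horse"
    seed = jax.random.PRNGKey(0)

    def run(use_memory_efficient_attention):
        pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint
            use_memory_efficient_attention=use_memory_efficient_attention,
        )
        prompt_ids = pipe.prepare_inputs([prompt])
        images = pipe(prompt_ids, params, seed, num_inference_steps=2, jit=False).images
        return np.asarray(images)

    # Same seed and inputs; only the attention implementation differs.
    default_images = run(False)
    efficient_images = run(True)
    assert np.abs(default_images - efficient_images).max() < 1e-2  # tolerance is a guess
```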

patrickvonplaten (Contributor):

@pcuenca happy to merge here!

patrickvonplaten merged commit dc27750 into main on Apr 12, 2023
patrickvonplaten deleted the 2231-flax-memory-efficient-attention branch on April 12, 2023 09:18
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* add use_memory_efficient params placeholder

* test

* add memory efficient attention jax

* add memory efficient attention jax

* newline

* forgot dot

* Rename use_memory_efficient

* Keep dtype last.

* Actually use key_chunk_size

* Rename symbol

* Apply style

* Rename use_memory_efficient

* Keep dtype last

* Pass `use_memory_efficient_attention` in `from_pretrained`

* Move JAX memory efficient attention to attention_flax.

* Simple test.

* style

---------

Co-authored-by: muhammad_hanif <[email protected]>
Co-authored-by: MuhHanif <[email protected]>
dg845 pushed a commit to dg845/diffusers that referenced this pull request May 6, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024