Flax memory efficient attention #2889
Changes from all commits
@@ -12,10 +12,110 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights
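To make the chunk-wise softmax rescaling above concrete, here is a small numerical check (illustrative only, not part of the diff). It assumes `_query_chunk_attention` as defined above is in scope and compares it against full softmax attention on tiny arrays:

```python
import jax
import jax.numpy as jnp

q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(q_rng, (8, 2, 16))   # (query_length, head, depth)
k = jax.random.normal(k_rng, (32, 2, 16))  # (key_value_length, head, depth)
v = jax.random.normal(v_rng, (32, 2, 16))

# Reference: ordinary (non-chunked) scaled dot-product attention.
scores = jnp.einsum("qhd,khd->qhk", q / jnp.sqrt(16), k, precision=jax.lax.Precision.HIGHEST)
reference = jnp.einsum(
    "qhk,khd->qhd", jax.nn.softmax(scores, axis=-1), v, precision=jax.lax.Precision.HIGHEST
)

# Chunked version: keys/values are processed 8 at a time, yet the result matches.
chunked = _query_chunk_attention(q, k, v, precision=jax.lax.Precision.HIGHEST, key_chunk_size=8)
assert jnp.allclose(reference, chunked, atol=1e-5)
```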
def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for the computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide the query array; must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide the key and value arrays; must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # carry: start index of the next chunk (unused)
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner,
        init=0,  # start counter
        xs=None,
        length=math.ceil(num_q / query_chunk_size),  # stop counter
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back together
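A minimal usage sketch for the public helper (again illustrative, not from the diff); note the docstring's requirement that the chunk sizes divide the corresponding lengths without remainder:

```python
import jax
import jax.numpy as jnp

q_rng, kv_rng = jax.random.split(jax.random.PRNGKey(0))
# (batch, length, head, depth_per_head)
query = jax.random.normal(q_rng, (1, 4096, 8, 64))
key = jax.random.normal(kv_rng, (1, 4096, 8, 64))
value = key

out = jax_memory_efficient_attention(query, key, value, query_chunk_size=1024, key_chunk_size=4096)
print(out.shape)  # (1, 4096, 8, 64) -- queries were processed in 4 chunks of 1024
```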
class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

@@ -37,6 +139,7 @@ class FlaxAttention(nn.Module):
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -77,13 +180,38 @@ def __call__(self, hidden_states, context=None, deterministic=True):
        key_states = self.reshape_heads_to_batch_dim(key_proj)
        value_states = self.reshape_heads_to_batch_dim(value_proj)

        if self.use_memory_efficient_attention:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this if statement creates a chunk size for each layer of the unet
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet

            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=2)

            # attend to values
            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)

        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return hidden_states
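For intuition, the heuristic above behaves like the standalone sketch below (illustrative only, not part of the diff): for a 512x512 image the flattened latent sequence is 64x64 = 4096 tokens, which yields a query chunk of 64.

```python
def pick_query_chunk_size(flatten_latent_dim: int) -> int:
    # Mirrors the if/elif ladder in FlaxAttention.__call__ above.
    if flatten_latent_dim % 64 == 0:
        return flatten_latent_dim // 64
    if flatten_latent_dim % 16 == 0:
        return flatten_latent_dim // 16
    if flatten_latent_dim % 4 == 0:
        return flatten_latent_dim // 4
    return flatten_latent_dim

for dim in (4096, 1024, 256, 64):
    print(dim, "->", pick_query_chunk_size(dim))
# 4096 -> 64, 1024 -> 16, 256 -> 4, 64 -> 1
```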
@@ -108,19 +236,26 @@ class FlaxBasicTransformerBlock(nn.Module):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """

    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        # cross attention
        self.attn2 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

Review comment (on the new `use_memory_efficient_attention` doc entry): nice!
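A hedged construction sketch for the updated block (not part of the PR; it assumes the module path `diffusers.models.attention_flax` and the `__call__(hidden_states, context, deterministic=...)` signature, and the shapes below are arbitrary examples):

```python
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxBasicTransformerBlock

block = FlaxBasicTransformerBlock(dim=320, n_heads=8, d_head=40, use_memory_efficient_attention=True)

hidden_states = jnp.zeros((1, 4096, 320))  # (batch, flattened latent tokens, dim)
context = jnp.zeros((1, 77, 768))          # (batch, text tokens, text embedding dim)

variables = block.init(jax.random.PRNGKey(0), hidden_states, context)
out = block.apply(variables, hidden_states, context)  # same shape as hidden_states
```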
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """

    in_channels: int
    n_heads: int

@@ -178,6 +315,7 @@ class FlaxTransformer2DModel(nn.Module):
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

@@ -202,6 +340,7 @@ def setup(self):
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
            )
            for _ in range(self.depth)
        ]
@@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pt = kwargs.pop("from_pt", False)
        use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)

Review comment: I decided to enable this at pipeline load time. It's not straightforward to replicate the recursive logic we created for […]. Another alternative would have been to pass […]. Any other thoughts about this @patrickvonplaten, @yiyixuxu, @williamberman?

Reply: I don't have a strong opinion about this - I think passing it as an argument is probably easier to implement, but what you already did is good :)

        dtype = kwargs.pop("dtype", None)

        # 1. Download the checkpoints and configs

@@ -451,7 +452,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                loaded_sub_model = cached_folder

                if issubclass(class_obj, FlaxModelMixin):
                    loaded_sub_model, loaded_params = load_method(
                        loadable_folder,
                        from_pt=from_pt,
                        use_memory_efficient_attention=use_memory_efficient_attention,
                        dtype=dtype,
                    )
                    params[name] = loaded_params
                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                    if from_pt:
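Following the discussion above, the flag is consumed at pipeline load time, so enabling it looks roughly like the sketch below (model id, revision, and dtype are illustrative choices; this requires a diffusers build that includes this PR):

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",   # example checkpoint with Flax weights
    revision="bf16",                   # example branch; pick one that exists for your model
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,  # forwarded to Flax sub-models as shown in the hunk above
)
```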