Flax memory efficient attention #2889


Merged · 24 commits · Apr 12, 2023
155 changes: 147 additions & 8 deletions src/diffusers/models/attention_flax.py
@@ -12,10 +12,110 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
num_kv, num_heads, k_features = key.shape[-3:]
v_features = value.shape[-1]
key_chunk_size = min(key_chunk_size, num_kv)
query = query / jnp.sqrt(k_features)

@functools.partial(jax.checkpoint, prevent_cse=False)
def summarize_chunk(query, key, value):
attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
max_score = jax.lax.stop_gradient(max_score)
exp_weights = jnp.exp(attn_weights - max_score)

exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
max_score = jnp.einsum("...qhk->...qh", max_score)

return (exp_values, exp_weights.sum(axis=-1), max_score)

def chunk_scanner(chunk_idx):
# julienne key array
key_chunk = jax.lax.dynamic_slice(
operand=key,
start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
)

# julienne value array
value_chunk = jax.lax.dynamic_slice(
operand=value,
start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
)

return summarize_chunk(query, key_chunk, value_chunk)

chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

global_max = jnp.max(chunk_max, axis=0, keepdims=True)
max_diffs = jnp.exp(chunk_max - global_max)

chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
chunk_weights *= max_diffs

all_values = chunk_values.sum(axis=0)
all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

return all_values / all_weights
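A sketch of what the recombination above computes (my notation, not part of the PR): for each key chunk c, summarize_chunk returns partial values v_c = Σ_k exp(a_k − m_c) V_k, partial weights s_c = Σ_k exp(a_k − m_c), and the per-chunk max m_c. With the global max M = max_c m_c, the chunks are merged as

attn = (Σ_c e^{m_c − M} v_c) / (Σ_c e^{m_c − M} s_c)

which equals the exact softmax-weighted sum while only ever materializing one chunk of attention weights at a time.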


def jax_memory_efficient_attention(
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
r"""
Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
https://github.com/AminRezaei0x443/memory-efficient-attention

Args:
query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
numerical precision for computation
query_chunk_size (`int`, *optional*, defaults to 1024):
chunk size to divide the query array; must divide query_length evenly with no remainder
key_chunk_size (`int`, *optional*, defaults to 4096):
chunk size to divide the key and value arrays; must divide key_value_length evenly with no remainder

Returns:
(`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
"""
num_q, num_heads, q_features = query.shape[-3:]

def chunk_scanner(chunk_idx, _):
# julienne query array
query_chunk = jax.lax.dynamic_slice(
operand=query,
start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
)

return (
chunk_idx + query_chunk_size,  # next chunk start index (the final carry returned by scan is unused)
_query_chunk_attention(
query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
),
)

_, res = jax.lax.scan(
    f=chunk_scanner,
    init=0,  # start counter
    xs=None,
    length=math.ceil(num_q / query_chunk_size),  # number of query chunks
)

return jnp.concatenate(res, axis=-3) # fuse the chunked result back
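A minimal usage sketch of the helper above (the import path and array shapes are my own choices for illustration, not part of the diff):

import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import jax_memory_efficient_attention

# (batch, length, heads, depth_per_head); lengths chosen so the default chunk sizes divide them evenly
q = jax.random.normal(jax.random.PRNGKey(0), (1, 4096, 8, 64))
k = jax.random.normal(jax.random.PRNGKey(1), (1, 4096, 8, 64))
v = jax.random.normal(jax.random.PRNGKey(2), (1, 4096, 8, 64))

out = jax_memory_efficient_attention(q, k, v, query_chunk_size=1024, key_chunk_size=4096)
print(out.shape)  # (1, 4096, 8, 64)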


class FlaxAttention(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`

@@ -37,6 +139,7 @@
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
@@ -77,13 +180,38 @@ def __call__(self, hidden_states, context=None, deterministic=True):
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)

# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
if self.use_memory_efficient_attention:
query_states = query_states.transpose(1, 0, 2)
key_states = key_states.transpose(1, 0, 2)
value_states = value_states.transpose(1, 0, 2)

# this if statement creates a chunk size for each layer of the UNet
# the chunk size is equal to the query_length dimension of the deepest layer of the UNet

flatten_latent_dim = query_states.shape[-3]
if flatten_latent_dim % 64 == 0:
query_chunk_size = int(flatten_latent_dim / 64)
elif flatten_latent_dim % 16 == 0:
query_chunk_size = int(flatten_latent_dim / 16)
elif flatten_latent_dim % 4 == 0:
query_chunk_size = int(flatten_latent_dim / 4)
else:
query_chunk_size = int(flatten_latent_dim)

hidden_states = jax_memory_efficient_attention(
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
)

hidden_states = hidden_states.transpose(1, 0, 2)
else:
# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)

# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)

# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return hidden_states
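For reference, a rough sketch of exercising the new flag directly on the module (the 320/8/40 configuration, sequence length, and shapes are assumptions for illustration; FlaxAttention's first positional argument is taken to be the query dimension):

import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(320, 8, 40, 0.0, use_memory_efficient_attention=True)

# (batch, seq_len, query_dim); for seq_len=4096 the heuristic above picks query_chunk_size = 4096 // 64 = 64
hidden_states = jnp.ones((1, 4096, 320), dtype=jnp.float32)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # same shape as the input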
@@ -108,19 +236,26 @@ class FlaxBasicTransformerBlock(nn.Module):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
"""
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False

def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxAttention(
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
)
# cross attention
self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn2 = FlaxAttention(
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
only_cross_attention (`bool`, defaults to `False`): tbd
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
"""
in_channels: int
n_heads: int
@@ -178,6 +315,7 @@
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False

def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -202,6 +340,7 @@ def setup(self):
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
use_memory_efficient_attention=self.use_memory_efficient_attention,
)
for _ in range(self.depth)
]
12 changes: 12 additions & 0 deletions src/diffusers/models/unet_2d_blocks_flax.py
@@ -37,6 +37,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
Number of attention heads of each spatial transformer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -48,6 +50,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
Expand All @@ -72,6 +75,7 @@ def setup(self):
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -172,6 +176,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -184,6 +190,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
Expand All @@ -209,6 +216,7 @@ def setup(self):
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -311,6 +319,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -319,6 +329,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
Expand All @@ -341,6 +352,7 @@ def setup(self):
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
6 changes: 6 additions & 0 deletions src/diffusers/models/unet_2d_condition_flax.py
@@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682

"""

@@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
@@ -169,6 +172,7 @@ def setup(self):
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
else:
@@ -190,6 +194,7 @@ def setup(self):
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)

@@ -217,6 +222,7 @@ def setup(self):
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
else:
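Because the flag is a plain module attribute, it can also be set when loading the UNet on its own. A hedged sketch (the checkpoint id, revision, and dtype are placeholders, not part of this PR):

import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint
    subfolder="unet",
    revision="flax",
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)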
8 changes: 7 additions & 1 deletion src/diffusers/pipelines/pipeline_flax_utils.py
@@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
Member Author: I decided to enable this at pipeline load time. It's not straightforward to replicate the recursive logic we created for xformers because Flax submodules are not available unless you are in apply, so I didn't find a way to make functions to enable/disable the setting on demand. Instead, the configuration is set up when the pipeline is instantiated.

Another alternative would have been to pass use_memory_efficient_attention as an additional argument to generate so it's applied on a per-inference basis. I thought making it constant for the pipeline made sense for the Flax case.

Any other thoughts about this @patrickvonplaten, @yiyixuxu, @williamberman?

Collaborator: I don't have a strong opinion about this - I think passing it as an argument is probably easier to implement, but what you already did is good :)
dtype = kwargs.pop("dtype", None)

# 1. Download the checkpoints and configs
@@ -451,7 +452,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
loaded_sub_model = cached_folder

if issubclass(class_obj, FlaxModelMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
loaded_sub_model, loaded_params = load_method(
loadable_folder,
from_pt=from_pt,
use_memory_efficient_attention=use_memory_efficient_attention,
dtype=dtype,
)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
if from_pt:
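Tying the discussion above together, enabling the feature at pipeline load time might look like this sketch (the checkpoint id, revision, and dtype are placeholders, not part of this PR):

import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint
    revision="bf16",
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,  # popped from kwargs and forwarded to every Flax submodel
)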