Commit 0313489

pcuenca, muhammad_hanif, and MuhHanif authored and committed
Flax memory efficient attention (huggingface#2889)
* add use_memory_efficient params placeholder
* test
* add memory efficient attention jax
* add memory efficient attention jax
* newline
* forgot dot
* Rename use_memory_efficient
* Keep dtype last.
* Actually use key_chunk_size
* Rename symbol
* Apply style
* Rename use_memory_efficient
* Keep dtype last
* Pass `use_memory_efficient_attention` in `from_pretrained`
* Move JAX memory efficient attention to attention_flax.
* Simple test.
* style

---------

Co-authored-by: muhammad_hanif <[email protected]>
Co-authored-by: MuhHanif <[email protected]>
1 parent 3748204 commit 0313489

File tree

5 files changed: +216 -9 lines changed


src/diffusers/models/attention_flax.py

+147 -8
@@ -12,10 +12,110 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
+import math
+
 import flax.linen as nn
+import jax
 import jax.numpy as jnp
 
 
+def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
+    """Multi-head dot product attention with a limited number of queries."""
+    num_kv, num_heads, k_features = key.shape[-3:]
+    v_features = value.shape[-1]
+    key_chunk_size = min(key_chunk_size, num_kv)
+    query = query / jnp.sqrt(k_features)
+
+    @functools.partial(jax.checkpoint, prevent_cse=False)
+    def summarize_chunk(query, key, value):
+        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
+
+        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+        max_score = jax.lax.stop_gradient(max_score)
+        exp_weights = jnp.exp(attn_weights - max_score)
+
+        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
+        max_score = jnp.einsum("...qhk->...qh", max_score)
+
+        return (exp_values, exp_weights.sum(axis=-1), max_score)
+
+    def chunk_scanner(chunk_idx):
+        # julienne key array
+        key_chunk = jax.lax.dynamic_slice(
+            operand=key,
+            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
+            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
+        )
+
+        # julienne value array
+        value_chunk = jax.lax.dynamic_slice(
+            operand=value,
+            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
+            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
+        )
+
+        return summarize_chunk(query, key_chunk, value_chunk)
+
+    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
+
+    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
+    max_diffs = jnp.exp(chunk_max - global_max)
+
+    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(axis=0)
+    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
+
+    return all_values / all_weights
+
+
+def jax_memory_efficient_attention(
+    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
+):
+    r"""
+    Flax memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
+    https://github.com/AminRezaei0x443/memory-efficient-attention
+
+    Args:
+        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
+        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
+        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
+        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
+            numerical precision for computation
+        query_chunk_size (`int`, *optional*, defaults to 1024):
+            chunk size to divide the query array; must divide query_length equally without remainder
+        key_chunk_size (`int`, *optional*, defaults to 4096):
+            chunk size to divide the key and value arrays; must divide key_value_length equally without remainder
+
+    Returns:
+        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
+    """
+    num_q, num_heads, q_features = query.shape[-3:]
+
+    def chunk_scanner(chunk_idx, _):
+        # julienne query array
+        query_chunk = jax.lax.dynamic_slice(
+            operand=query,
+            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
+            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
+        )
+
+        return (
+            chunk_idx + query_chunk_size,  # unused, ignore it
+            _query_chunk_attention(
+                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
+            ),
+        )
+
+    _, res = jax.lax.scan(
+        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
+    )
+
+    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
+
+
 class FlaxAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
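
The function added above can be sanity-checked against a naive softmax attention on small inputs. The snippet below is a rough sketch, not part of the commit: the shapes follow the (batch..., length, head, depth) convention from the docstring, and the comparison assumes float32 arithmetic.

import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import jax_memory_efficient_attention


def naive_attention(query, key, value):
    # reference implementation: materialize the full (query_length, key_value_length) score matrix
    scale = 1.0 / jnp.sqrt(query.shape[-1])
    scores = jnp.einsum("...qhd,...khd->...hqk", query, key) * scale
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("...hqk,...khd->...qhd", probs, value)


q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (2, 1024, 8, 64))  # (batch, query_length, head, depth)
key = jax.random.normal(k_rng, (2, 1024, 8, 64))
value = jax.random.normal(v_rng, (2, 1024, 8, 64))

chunked = jax_memory_efficient_attention(query, key, value, query_chunk_size=256, key_chunk_size=512)
reference = naive_attention(query, key, value)
print(jnp.max(jnp.abs(chunked - reference)))  # expected to be on the order of 1e-6 in float32

Both chunk sizes here (256 and 512) divide the corresponding lengths, as the docstring requires; only one query chunk and one key chunk of scores are materialized at a time, which is where the memory saving comes from.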
@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
             Hidden states dimension inside each head
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
 
@@ -37,6 +139,7 @@ class FlaxAttention(nn.Module):
     heads: int = 8
     dim_head: int = 64
     dropout: float = 0.0
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -77,13 +180,38 @@ def __call__(self, hidden_states, context=None, deterministic=True):
         key_states = self.reshape_heads_to_batch_dim(key_proj)
         value_states = self.reshape_heads_to_batch_dim(value_proj)
 
-        # compute attentions
-        attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
-        attention_scores = attention_scores * self.scale
-        attention_probs = nn.softmax(attention_scores, axis=2)
+        if self.use_memory_efficient_attention:
+            query_states = query_states.transpose(1, 0, 2)
+            key_states = key_states.transpose(1, 0, 2)
+            value_states = value_states.transpose(1, 0, 2)
+
+            # this if statement creates a chunk size for each layer of the unet
+            # the chunk size is equal to the query_length dimension of the deepest layer of the unet
+
+            flatten_latent_dim = query_states.shape[-3]
+            if flatten_latent_dim % 64 == 0:
+                query_chunk_size = int(flatten_latent_dim / 64)
+            elif flatten_latent_dim % 16 == 0:
+                query_chunk_size = int(flatten_latent_dim / 16)
+            elif flatten_latent_dim % 4 == 0:
+                query_chunk_size = int(flatten_latent_dim / 4)
+            else:
+                query_chunk_size = int(flatten_latent_dim)
+
+            hidden_states = jax_memory_efficient_attention(
+                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
+            )
+
+            hidden_states = hidden_states.transpose(1, 0, 2)
+        else:
+            # compute attentions
+            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+            attention_scores = attention_scores * self.scale
+            attention_probs = nn.softmax(attention_scores, axis=2)
+
+            # attend to values
+            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
 
-        # attend to values
-        hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
         return hidden_states
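
At the module level, the flag simply switches FlaxAttention between the chunked path and the dense einsum path shown above. A minimal usage sketch (it assumes, as elsewhere in the library, that the module's first field is query_dim, which this hunk does not show):

import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttention

# 64 latent positions hit the first branch of the chunk-size heuristic (64 % 64 == 0)
attn = FlaxAttention(query_dim=320, heads=8, dim_head=40, use_memory_efficient_attention=True)

hidden_states = jnp.ones((1, 64, 320), dtype=jnp.float32)  # (batch, sequence, query_dim)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)
print(out.shape)  # (1, 64, 320), same shape as the dense path would return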
@@ -108,19 +236,26 @@ class FlaxBasicTransformerBlock(nn.Module):
             Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
     """
     dim: int
     n_heads: int
     d_head: int
     dropout: float = 0.0
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    use_memory_efficient_attention: bool = False
 
     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttention(
+            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+        )
         # cross attention
-        self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttention(
+            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+        )
         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
         only_cross_attention (`bool`, defaults to `False`): tbd
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
     """
     in_channels: int
     n_heads: int
@@ -178,6 +315,7 @@ class FlaxTransformer2DModel(nn.Module):
     use_linear_projection: bool = False
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    use_memory_efficient_attention: bool = False
 
     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -202,6 +340,7 @@ def setup(self):
                 dropout=self.dropout,
                 only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
             )
             for _ in range(self.depth)
         ]
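
Because FlaxBasicTransformerBlock and FlaxTransformer2DModel now expose the same field, the flag propagates down to every FlaxAttention they build. A hedged sketch follows; the dimensions, the channels-last layout of hidden_states, and the context shape are illustrative assumptions rather than values taken from the commit.

import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxTransformer2DModel

# depth=1 mirrors how the UNet blocks instantiate this module
model = FlaxTransformer2DModel(in_channels=320, n_heads=8, d_head=40, depth=1, use_memory_efficient_attention=True)

hidden_states = jnp.ones((1, 32, 32, 320))  # channels-last latent feature map
context = jnp.ones((1, 77, 768))  # e.g. text-encoder hidden states
params = model.init(jax.random.PRNGKey(0), hidden_states, context)
out = model.apply(params, hidden_states, context)
print(out.shape)  # (1, 32, 32, 320)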

src/diffusers/models/unet_2d_blocks_flax.py

+12
@@ -37,6 +37,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             Number of attention heads of each spatial transformer block
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsampling layer before each final output
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -48,6 +50,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     add_downsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -72,6 +75,7 @@ def setup(self):
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -172,6 +176,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             Number of attention heads of each spatial transformer block
         add_upsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add upsampling layer before each final output
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -184,6 +190,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     add_upsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -209,6 +216,7 @@ def setup(self):
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -311,6 +319,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
             Number of attention blocks layers
         attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -319,6 +329,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     use_linear_projection: bool = False
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -341,6 +352,7 @@ def setup(self):
                 d_head=self.in_channels // self.attn_num_head_channels,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)

src/diffusers/models/unet_2d_condition_flax.py

+6
@@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
 
     """
 
@@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     dtype: jnp.dtype = jnp.float32
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
+    use_memory_efficient_attention: bool = False
 
     def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors
@@ -169,6 +172,7 @@ def setup(self):
                     add_downsample=not is_final_block,
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
+                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                     dtype=self.dtype,
                 )
             else:
@@ -190,6 +194,7 @@ def setup(self):
             dropout=self.dropout,
             attn_num_head_channels=attention_head_dim[-1],
             use_linear_projection=self.use_linear_projection,
+            use_memory_efficient_attention=self.use_memory_efficient_attention,
             dtype=self.dtype,
         )
 
@@ -217,6 +222,7 @@ def setup(self):
                     dropout=self.dropout,
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
+                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                     dtype=self.dtype,
                 )
             else:
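
Since use_memory_efficient_attention is now a regular config field of FlaxUNet2DConditionModel, it can be overridden when loading pretrained weights. The sketch below assumes the checkpoint ships Flax weights in a unet subfolder; the repository id is illustrative, not part of the commit.

import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint id
    subfolder="unet",
    use_memory_efficient_attention=True,  # overrides the config default of False
    dtype=jnp.bfloat16,
)
print(unet.use_memory_efficient_attention)  # True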

src/diffusers/pipelines/pipeline_flax_utils.py

+7 -1
@@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_pt = kwargs.pop("from_pt", False)
+        use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
         dtype = kwargs.pop("dtype", None)
 
         # 1. Download the checkpoints and configs
@@ -451,7 +452,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     loaded_sub_model = cached_folder
 
                 if issubclass(class_obj, FlaxModelMixin):
-                    loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                    loaded_sub_model, loaded_params = load_method(
+                        loadable_folder,
+                        from_pt=from_pt,
+                        use_memory_efficient_attention=use_memory_efficient_attention,
+                        dtype=dtype,
+                    )
                     params[name] = loaded_params
                 elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                     if from_pt:
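
At the pipeline level, the new kwarg popped above is forwarded to the from_pretrained call of each diffusers Flax sub-model (those inheriting FlaxModelMixin), so enabling the feature is a one-argument change. A sketch, assuming FlaxStableDiffusionPipeline and a checkpoint that provides a bf16 Flax revision; the repository id and revision are illustrative.

import jax.numpy as jnp

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint id and revision
    revision="bf16",
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)
print(pipeline.unet.use_memory_efficient_attention)  # True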
