 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
+import math
+
 import flax.linen as nn
+import jax
 import jax.numpy as jnp
 
 
+def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
+    """Multi-head dot product attention with a limited number of queries."""
+    num_kv, num_heads, k_features = key.shape[-3:]
+    v_features = value.shape[-1]
+    key_chunk_size = min(key_chunk_size, num_kv)
+    query = query / jnp.sqrt(k_features)
+
+    @functools.partial(jax.checkpoint, prevent_cse=False)
+    def summarize_chunk(query, key, value):
+        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
+
+        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+        max_score = jax.lax.stop_gradient(max_score)
+        exp_weights = jnp.exp(attn_weights - max_score)
+
+        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
+        max_score = jnp.einsum("...qhk->...qh", max_score)
+
+        return (exp_values, exp_weights.sum(axis=-1), max_score)
+
+    def chunk_scanner(chunk_idx):
+        # julienne key array
+        key_chunk = jax.lax.dynamic_slice(
+            operand=key,
+            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
+            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
+        )
+
+        # julienne value array
+        value_chunk = jax.lax.dynamic_slice(
+            operand=value,
+            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
+            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
+        )
+
+        return summarize_chunk(query, key_chunk, value_chunk)
+
+    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
+
+    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
+    max_diffs = jnp.exp(chunk_max - global_max)
+
+    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(axis=0)
+    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
+
+    return all_values / all_weights
+
+
+def jax_memory_efficient_attention(
+    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
+):
+    r"""
+    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
+    https://github.com/AminRezaei0x443/memory-efficient-attention
+
+    Args:
+        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
+        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
+        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
+        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
+            numerical precision for computation
+        query_chunk_size (`int`, *optional*, defaults to 1024):
+            chunk size to divide the query array; the value must divide query_length without remainder
+        key_chunk_size (`int`, *optional*, defaults to 4096):
+            chunk size to divide the key and value arrays; the value must divide key_value_length without remainder
+
+    Returns:
+        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
+    """
+    num_q, num_heads, q_features = query.shape[-3:]
+
+    def chunk_scanner(chunk_idx, _):
+        # julienne query array
+        query_chunk = jax.lax.dynamic_slice(
+            operand=query,
+            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
+            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
+        )
+
+        return (
+            chunk_idx + query_chunk_size,  # unused, ignore it
+            _query_chunk_attention(
+                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
+            ),
+        )
+
+    _, res = jax.lax.scan(
+        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
+    )
+
+    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
+
+
 class FlaxAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
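For reference, a minimal usage sketch of the jax_memory_efficient_attention helper added above, assuming both new functions are in scope; the tensor sizes are illustrative and follow the documented (batch..., length, head, depth) layout, with chunk sizes that divide the sequence lengths without remainder:

import jax.numpy as jnp

# dummy self-attention inputs: batch=1, 4096 tokens, 8 heads, 64 features per head
q = jnp.ones((1, 4096, 8, 64))  # (batch, query_length, head, query_key_depth_per_head)
k = jnp.ones((1, 4096, 8, 64))  # (batch, key_value_length, head, query_key_depth_per_head)
v = jnp.ones((1, 4096, 8, 64))  # (batch, key_value_length, head, value_depth_per_head)

# the defaults query_chunk_size=1024 and key_chunk_size=4096 both divide 4096 evenly
out = jax_memory_efficient_attention(q, k, v)
assert out.shape == (1, 4096, 8, 64)

The full query_length x key_value_length attention matrix is never materialized at once; only query_chunk_size x key_chunk_size blocks are, and the checkpointed summarize_chunk keeps the backward pass from storing every block.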
@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
             Hidden states dimension inside each head
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
 
@@ -37,6 +139,7 @@ class FlaxAttention(nn.Module):
     heads: int = 8
     dim_head: int = 64
     dropout: float = 0.0
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -77,13 +180,38 @@ def __call__(self, hidden_states, context=None, deterministic=True):
         key_states = self.reshape_heads_to_batch_dim(key_proj)
         value_states = self.reshape_heads_to_batch_dim(value_proj)
 
-        # compute attentions
-        attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
-        attention_scores = attention_scores * self.scale
-        attention_probs = nn.softmax(attention_scores, axis=2)
+        if self.use_memory_efficient_attention:
+            query_states = query_states.transpose(1, 0, 2)
+            key_states = key_states.transpose(1, 0, 2)
+            value_states = value_states.transpose(1, 0, 2)
+
+            # this if statement creates a chunk size for each layer of the unet
+            # the chunk size is equal to the query_length dimension of the deepest layer of the unet
+
+            flatten_latent_dim = query_states.shape[-3]
+            if flatten_latent_dim % 64 == 0:
+                query_chunk_size = int(flatten_latent_dim / 64)
+            elif flatten_latent_dim % 16 == 0:
+                query_chunk_size = int(flatten_latent_dim / 16)
+            elif flatten_latent_dim % 4 == 0:
+                query_chunk_size = int(flatten_latent_dim / 4)
+            else:
+                query_chunk_size = int(flatten_latent_dim)
+
+            hidden_states = jax_memory_efficient_attention(
+                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
+            )
+
+            hidden_states = hidden_states.transpose(1, 0, 2)
+        else:
+            # compute attentions
+            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+            attention_scores = attention_scores * self.scale
+            attention_probs = nn.softmax(attention_scores, axis=2)
+
+            # attend to values
+            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
 
-        # attend to values
-        hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
         return hidden_states
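For intuition, the chunk-size heuristic in the branch above picks the largest of 64, 16, or 4 that divides the flattened latent length. The sketch below restates it as a standalone function; pick_query_chunk_size is a hypothetical name and the 512x512 resolution is only an assumed example:

def pick_query_chunk_size(flatten_latent_dim: int) -> int:
    # mirrors the if/elif chain above: prefer the largest divisor among 64, 16, 4
    for divisor in (64, 16, 4):
        if flatten_latent_dim % divisor == 0:
            return flatten_latent_dim // divisor
    return flatten_latent_dim

# assuming a 512x512 Stable Diffusion run, the top UNet attention block sees a 64x64 latent,
# i.e. 4096 flattened query positions; 4096 // 64 = 64 matches the 8x8 = 64 positions of the
# deepest block, as the comment in the patch describes
print(pick_query_chunk_size(64 * 64))  # 64
print(pick_query_chunk_size(8 * 8))    # 1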
@@ -108,19 +236,26 @@ class FlaxBasicTransformerBlock(nn.Module):
             Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
     """
     dim: int
     n_heads: int
     d_head: int
     dropout: float = 0.0
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    use_memory_efficient_attention: bool = False
 
     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttention(
+            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+        )
         # cross attention
-        self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttention(
+            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+        )
         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
         only_cross_attention (`bool`, defaults to `False`): tbd
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
     """
     in_channels: int
     n_heads: int
@@ -178,6 +315,7 @@ class FlaxTransformer2DModel(nn.Module):
     use_linear_projection: bool = False
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    use_memory_efficient_attention: bool = False
 
     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -202,6 +340,7 @@ def setup(self):
                 dropout=self.dropout,
                 only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
             )
             for _ in range(self.depth)
         ]
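Finally, a minimal sketch of turning the new flag on at the module level. The positional arguments mirror the FlaxAttention(...) call in FlaxBasicTransformerBlock.setup above; the first argument is assumed to be the query feature dimension (its field is not shown in this diff), and all concrete sizes are illustrative:

import jax
import jax.numpy as jnp

# 320 input channels, 8 heads of 40 features, no dropout, memory-efficient attention on
attn = FlaxAttention(320, 8, 40, 0.0, True, dtype=jnp.float32)
hidden_states = jnp.ones((1, 64 * 64, 320))  # (batch, flattened 64x64 latent, channels)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # self-attention path; output keeps shape (1, 4096, 320)

With 4096 query positions, the branch in __call__ picks query_chunk_size = 4096 // 64 = 64 before dispatching to jax_memory_efficient_attention.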