@@ -86,29 +86,24 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
 
 # Let's explore the speed of each of the 3 implementations
-from torch.backends.cuda import sdp_kernel, SDPBackend
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
-# Helpful arguments mapper
-backend_map = {
-    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
-    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
-    SDPBackend.EFFICIENT_ATTENTION: {
-        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
-}
 
-with sdp_kernel(**backend_map[SDPBackend.MATH]):
-    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+with sdpa_kernel(SDPBackend.MATH):
+    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+    print(f"The math implementation runs in {math_time:.3f} microseconds")
 
-
-with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     try:
-        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
     except RuntimeError:
         print("FlashAttention is not supported. See warnings for reasons.")
 
-with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
     try:
-        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
     except RuntimeError:
         print("EfficientAttention is not supported. See warnings for reasons.")
 
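The three timed blocks above differ only in which ``SDPBackend`` is selected, so the same measurement can also be written as a loop. This is a minimal editorial sketch, not part of the proposed change; it reuses the tutorial's ``benchmark_torch_function_in_microseconds`` helper and the ``query``, ``key``, ``value`` tensors defined earlier, and assumes the enum's ``.name`` attribute for printing:

from torch.nn.attention import SDPBackend, sdpa_kernel

# Try each backend in turn; backends that are unsupported on the current
# hardware/dtype raise RuntimeError inside scaled_dot_product_attention.
for backend in (SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    with sdpa_kernel(backend):
        try:
            time_us = benchmark_torch_function_in_microseconds(
                F.scaled_dot_product_attention, query, key, value
            )
            print(f"{backend.name} implementation runs in {time_us:.3f} microseconds")
        except RuntimeError:
            print(f"{backend.name} is not supported. See warnings for reasons.")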
@@ -239,7 +234,7 @@ def generate_rand_batch(
 # Currently the fused implementations don't support ``NestedTensor`` for training
 model.eval()
 
-with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     try:
         print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
         print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
@@ -328,14 +323,82 @@ def generate_rand_batch(
 # the Shakespeare dataset.
 #
 
+######################################################################
+# Using SDPA with ``attn_bias`` subclasses
+# ==========================================
+#
+# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
+# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
+# The module is named ``torch.nn.attention.bias`` and contains the following two
+# utilities for generating causal attention variants:
+#
+# - ``torch.nn.attention.bias.causal_upper_left``
+# - ``torch.nn.attention.bias.causal_lower_right``
+#
+# .. note::
+#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
+#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
+#
+
+from torch.nn.attention.bias import causal_lower_right, causal_upper_left
+
+batch_size = 32
+sequence_length_q = 2
+sequence_length_kv = 10
+num_heads = 16
+embed_dimension = 32
+
+dtype = torch.float16
+
+query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
+key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
+value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
+
+upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
+lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
+
+print(type(upper_left_bias))
+print(type(lower_right_bias))
+
+assert type(upper_left_bias) == type(lower_right_bias)
+assert issubclass(type(upper_left_bias), torch.Tensor)
+
+# As you can see from the previous output, both biases are of the same type,
+# ``torch.nn.attention.bias.CausalBias``, and both subclass ``torch.Tensor``.
+
+# Let's see what these tensors look like
+print(upper_left_bias)
+print(lower_right_bias)
+
+# Upper left bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
+# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
+# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
+# between the 0th token in the query and the 0th token in the key, so with upper left bias the 0th token
+# in the query is aligned to the 0th token in the key.
+# With lower right bias, the sequences are instead aligned at the end: the last token in q is aligned to the
+# last token in k (for example, the mask entry at ``attn_score[-1][-1]`` is always True, since the last token
+# in q is at the same position as the last token in k), even if the sequence lengths of q and k are different.
+# A sketch of the two mask layouts follows below.
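To make the alignment concrete, here is a small illustrative sketch (an editorial aside, not part of this diff) that builds the two layouts as plain boolean masks with ``torch.tril``; the diagonal offset of ``kv_len - q_len`` for the lower-right variant is the point being illustrated, not a claim about ``CausalBias`` internals:

import torch

q_len, kv_len = 2, 10  # same shapes as sequence_length_q / sequence_length_kv above

# Upper-left alignment: query token 0 lines up with key token 0.
upper_left_mask = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool))

# Lower-right alignment: the last query token lines up with the last key token,
# so the causal diagonal is shifted by kv_len - q_len.
lower_right_mask = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=kv_len - q_len)

print(upper_left_mask)   # row 0 sees only key 0; row 1 sees keys 0-1
print(lower_right_mask)  # row 0 sees keys 0-8; row 1 sees all 10 keys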
+
+# These objects are intended to be used with SDPA, passed as the ``attn_mask`` argument
+out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
+out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
+out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
+
+assert torch.allclose(out_upper_left, out_is_causal)
+assert not torch.allclose(out_upper_left, out_lower_right)
+
+# These attention biases should also be compatible with torch.compile
+compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
+out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
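As a quick sanity check (an editorial sketch, not part of this diff), the compiled results can be compared against the eager outputs computed above; assuming the compiled and eager calls dispatch to the same backend, the values should agree within ``torch.allclose`` tolerances:

# ``out_upper_left`` now holds the compiled result; it should still match the
# eager ``is_causal=True`` output computed earlier.
assert torch.allclose(out_upper_left, out_is_causal)

# The lower-right bias runs through the compiled function as well.
out_lower_right_compiled = compiled_sdpa(query, key, value, lower_right_bias)
assert torch.allclose(out_lower_right_compiled, out_lower_right)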
 
 
 ######################################################################
 # Conclusion
 # ==========
 #
 # In this tutorial, we have demonstrated the basic usage of
 # ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
-# the ``sdp_kernel`` context manager can be used to assert a certain
+# the ``sdpa_kernel`` context manager can be used to assert a certain
 # implementation is used on GPU. As well, we built a simple
 # ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch
 # compilable. In the process we have shown how the profiling tools can