
Commit 700b9d8

drisspg and svekars authored
Update the sdpa docs to use new context manager and attention_bias examples (#2831)
* update the sdpa docs to use new context manager and attention_bias example

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent f6024bf commit 700b9d8
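In short, the commit moves the tutorial off the deprecated flag-style context manager and onto the new one. A minimal, self-contained sketch of the change (editor's illustration, not part of the commit; it assumes a CUDA device where FlashAttention is available):

# Sketch of the API migration this commit performs in the tutorial.
import torch
import torch.nn.functional as F

device = "cuda"
query = torch.rand(8, 16, 64, 64, device=device, dtype=torch.float16)
key = torch.rand(8, 16, 64, 64, device=device, dtype=torch.float16)
value = torch.rand(8, 16, 64, 64, device=device, dtype=torch.float16)

# Old style (torch.backends.cuda.sdp_kernel, deprecated): toggle three boolean flags.
from torch.backends.cuda import sdp_kernel
with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
    out_old = F.scaled_dot_product_attention(query, key, value)

# New style (torch.nn.attention.sdpa_kernel, adopted by this commit): name the backend directly.
from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_new = F.scaled_dot_product_attention(query, key, value)

The new form names the backend directly instead of toggling three boolean flags, which is exactly what the `backend_map` dictionary removed by this commit used to emulate.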


intermediate_source/scaled_dot_product_attention_tutorial.py

Lines changed: 80 additions & 17 deletions
@@ -86,29 +86,24 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
 
 # Lets explore the speed of each of the 3 implementations
-from torch.backends.cuda import sdp_kernel, SDPBackend
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
-# Helpful arguments mapper
-backend_map = {
-    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
-    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
-    SDPBackend.EFFICIENT_ATTENTION: {
-        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
-}
 
-with sdp_kernel(**backend_map[SDPBackend.MATH]):
-    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+with sdpa_kernel(SDPBackend.MATH):
+    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+    print(f"The math implementation runs in {math_time:.3f} microseconds")
 
-
-with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     try:
-        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
     except RuntimeError:
         print("FlashAttention is not supported. See warnings for reasons.")
 
-with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
     try:
-        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
     except RuntimeError:
         print("EfficientAttention is not supported. See warnings for reasons.")
 
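Editor's aside, not part of the diff: the PyTorch 2.3 documentation for torch.nn.attention.sdpa_kernel describes the context manager as also accepting a list of backends, which would let code like the benchmark above prefer one fused kernel while still allowing another as a fallback. A hedged sketch, assuming that list form and a CUDA device with half-precision inputs:

# Sketch: allow either fused kernel; the runtime picks a supported one from the list.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda"
q = torch.rand(8, 16, 128, 64, device=device, dtype=torch.float16)
k = torch.rand(8, 16, 128, 64, device=device, dtype=torch.float16)
v = torch.rand(8, 16, 128, 64, device=device, dtype=torch.float16)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)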

@@ -239,7 +234,7 @@ def generate_rand_batch(
 # Currently the fused implementations don't support ``NestedTensor`` for training
 model.eval()
 
-with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     try:
         print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
         print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
@@ -328,14 +323,82 @@ def generate_rand_batch(
 # the Shakespeare dataset.
 #
 
+######################################################################
+# Using SDPA with attn_bias subclasses
+# ==========================================
+#
+# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
+# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
+# The module is named ``torch.nn.attention.bias`` and contains the following two
+# utilities for generating causal attention variants:
+#
+# - ``torch.nn.attention.bias.causal_upper_left``
+# - ``torch.nn.attention.bias.causal_lower_right``
+#
+# .. note::
+#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
+#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
+#
+
+from torch.nn.attention.bias import causal_lower_right, causal_upper_left
+
+batch_size = 32
+sequence_length_q = 2
+sequence_length_kv = 10
+num_heads = 16
+embed_dimension = 32
+
+dtype = torch.float16
+
+query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
+key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
+value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
+
+upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
+lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
+
+print(type(upper_left_bias))
+print(type(lower_right_bias))
+
+assert type(upper_left_bias) == type(lower_right_bias)
+assert issubclass(type(upper_left_bias), torch.Tensor)
+
+# As you can see from the previous output, both biases are of the same type,
+# ``torch.nn.attention.bias.CausalBias``, and both subclass ``torch.Tensor``.
+
+# Let's see what these tensors look like
+print(upper_left_bias)
+print(lower_right_bias)
+
+# The upper left bias aligns the causal attention mask to the upper left corner of the attention
+# scores matrix. This only has an impact when the attention scores matrix is not square, which is
+# common for decoding use cases. Assuming the attention score matrix is two dimensional,
+# ``attn_score[0][0]`` is the attention score between the 0th token in the query and the 0th
+# token in the key. With upper left bias, the 0th token in the query is therefore aligned to the
+# 0th token in the key. With lower right bias, the query is instead aligned so that its last token
+# lines up with the last token in the key: for example, ``attn_score[-1][-1]`` is never masked,
+# since the last token in q is at the same position as the last token in k, even if the sequence
+# lengths of q and k are different.

+# These objects are intended to be used with sdpa
+out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
+out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
+out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
+
+assert torch.allclose(out_upper_left, out_is_causal)
+assert not torch.allclose(out_upper_left, out_lower_right)
+
+# These attention biases should also be compatible with torch.compile
+compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
+out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
 
 ######################################################################
 # Conclusion
 # ==========
 #
 # In this tutorial, we have demonstrated the basic usage of
 # ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
-# the ``sdp_kernel`` context manager can be used to assert a certain
+# the ``sdpa_kernel`` context manager can be used to assert a certain
 # implementation is used on GPU. As well, we built a simple
 # ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch
 # compilable. In the process we have shown how the profiling tools can
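Editor's note on the new attn_bias section above: the difference between the two alignments can also be seen with plain boolean masks, without the CausalBias classes. A small standalone sketch for sequence_length_q = 2 and sequence_length_kv = 10 (illustrative only, not part of the commit):

import torch

seq_len_q, seq_len_kv = 2, 10

# Upper-left alignment: query token i may attend to key tokens 0..i.
upper_left_mask = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool).tril(diagonal=0)

# Lower-right alignment: the last query token lines up with the last key token,
# so the causal diagonal is shifted right by seq_len_kv - seq_len_q.
lower_right_mask = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool).tril(diagonal=seq_len_kv - seq_len_q)

print(upper_left_mask)
print(lower_right_mask)

With the upper-left mask only the first two key positions are ever visible to the two query tokens, while the lower-right mask lets both query tokens see the prefix ending at the last key token, which matches the decoding use case described in the diff.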
