
Commit 1c3d9e8

Merge branch 'main' into svekars-patch-22
2 parents 21849f9 + c5c0a9a commit 1c3d9e8

File tree

2 files changed: +118 / -52 lines changed


intermediate_source/scaled_dot_product_attention_tutorial.py

Lines changed: 80 additions & 17 deletions
@@ -86,29 +86,24 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

 # Lets explore the speed of each of the 3 implementations
-from torch.backends.cuda import sdp_kernel, SDPBackend
+from torch.nn.attention import SDPBackend, sdpa_kernel

-# Helpful arguments mapper
-backend_map = {
-    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
-    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
-    SDPBackend.EFFICIENT_ATTENTION: {
-        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
-}

-with sdp_kernel(**backend_map[SDPBackend.MATH]):
-    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+with sdpa_kernel(SDPBackend.MATH):
+    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+    print(f"The math implementation runs in {math_time:.3f} microseconds")

-
-with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     try:
-        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
     except RuntimeError:
         print("FlashAttention is not supported. See warnings for reasons.")

-with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
     try:
-        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
+        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
     except RuntimeError:
         print("EfficientAttention is not supported. See warnings for reasons.")
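
For quick reference, a minimal standalone sketch of the new ``sdpa_kernel`` context manager used in the hunk above (not part of the commit; assumes PyTorch >= 2.3 and a CUDA device, with illustrative tensor shapes):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # (batch, num_heads, seq_len, head_dim) -- arbitrary sizes for illustration
    query = torch.rand(8, 16, 64, 64, device="cuda", dtype=torch.float16)
    key = torch.rand(8, 16, 64, 64, device="cuda", dtype=torch.float16)
    value = torch.rand(8, 16, 64, 64, device="cuda", dtype=torch.float16)

    # Restrict SDPA to a single backend (replaces the old dict of enable_* flags).
    with sdpa_kernel(SDPBackend.MATH):
        out = F.scaled_dot_product_attention(query, key, value)

    # Recent releases also accept a list of allowed backends; verify this against
    # your installed PyTorch version before relying on it.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(query, key, value)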

@@ -239,7 +234,7 @@ def generate_rand_batch(
 # Currently the fused implementations don't support ``NestedTensor`` for training
 model.eval()

-with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     try:
         print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
         print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
@@ -328,14 +323,82 @@ def generate_rand_batch(
 # the Shakespeare dataset.
 #

+######################################################################
+# Using SDPA with attn_bias subclasses
+# ==========================================
+#
+# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
+# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
+# The module is named ``torch.nn.attention.bias`` and contains the following two
+# utilities for generating causal attention variants:
+#
+# - ``torch.nn.attention.bias.causal_upper_left``
+# - ``torch.nn.attention.bias.causal_lower_right``
+#
+# .. note::
+#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
+#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
+#
+
+from torch.nn.attention.bias import causal_lower_right, causal_upper_left
+
+batch_size = 32
+sequence_length_q = 2
+sequence_length_kv = 10
+num_heads = 16
+embed_dimension = 32
+
+dtype = torch.float16
+
+query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
+key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
+value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
+
+upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
+lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
+
+print(type(upper_left_bias))
+print(type(lower_right_bias))
+
+assert type(upper_left_bias) == type(lower_right_bias)
+assert issubclass(type(upper_left_bias), torch.Tensor)
+
+# As you can see from the previous output, both objects are of the same type,
+# ``torch.nn.attention.bias.CausalBias``, and both subclass ``torch.Tensor``.
+
+# Let's see what these tensors look like
+print(upper_left_bias)
+print(lower_right_bias)
+
+# Upper-left bias aligns the causal attention mask to the upper-left corner of the attention scores matrix.
+# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
+# Another way of thinking about this concept: with upper-left bias, the 0th token in the query is aligned
+# to the 0th token in the key. Assuming the attention score matrix is two dimensional, ``attn_score[0][0]``
+# is the attention score between the 0th token in the query and the 0th token in the key.
+# With lower-right bias, the query sequence is instead aligned so that its last token lines up with the
+# last token in the key (for example, ``attn_score[-1][-1]`` is always unmasked, since the last token in q
+# sits at the same position as the last token in k, even when the sequence lengths of q and k differ).
+
+# These objects are intended to be used with sdpa
+out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
+out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
+out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
+
+assert torch.allclose(out_upper_left, out_is_causal)
+assert not torch.allclose(out_upper_left, out_lower_right)
+
+# These attention biases should also be compatible with torch.compile
+compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
+out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)

 ######################################################################
 # Conclusion
 # ==========
 #
 # In this tutorial, we have demonstrated the basic usage of
 # ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
-# the ``sdp_kernel`` context manager can be used to assert a certain
+# the ``sdpa_kernel`` context manager can be used to assert a certain
 # implementation is used on GPU. As well, we built a simple
 # ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch
 # compilable. In the process we have shown how to the profiling tools can
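
A minimal sketch of the alignment difference described in the added comments above, using plain boolean masks built with ``torch.tril`` (not part of the commit; shapes and names are illustrative only):

    import torch

    seq_q, seq_kv = 2, 10

    # Upper-left alignment: query token 0 lines up with key token 0,
    # so query row i may attend to keys 0..i.
    upper_left_mask = torch.tril(torch.ones(seq_q, seq_kv, dtype=torch.bool))

    # Lower-right alignment: the last query token lines up with the last key token,
    # so query row i may attend to keys 0..(seq_kv - seq_q + i).
    lower_right_mask = torch.tril(
        torch.ones(seq_q, seq_kv, dtype=torch.bool),
        diagonal=seq_kv - seq_q,
    )

    print(upper_left_mask)   # row 0 attends to key 0 only; row 1 to keys 0-1
    print(lower_right_mask)  # row 0 attends to keys 0-8; row 1 to all 10 keys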

recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 38 additions & 35 deletions
@@ -1,7 +1,7 @@
 Getting Started with Distributed Checkpoint (DCP)
 =====================================================

-**Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__
+**Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__, `Lucas Pasqualin <https://github.com/lucasllc>`__

 .. note::
    |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__.
@@ -22,8 +22,12 @@ In this tutorial, we show how to use DCP APIs with a simple FSDP wrapped model.
 How DCP works
 --------------

-:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel.
-In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies.
+:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel,
+and then re-shard across differing cluster topologies at load time.
+
+Additionally, through the use of modules in :func:`torch.distributed.checkpoint.state_dict`,
+DCP offers support for gracefully handling ``state_dict`` generation and loading in distributed settings.
+This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.

 DCP is different from :func:`torch.save` and :func:`torch.load` in a few significant ways:

@@ -42,19 +46,20 @@ Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly,
 Saving
 ~~~~~~

-Now, lets create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
+Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.

 .. code-block:: python

     import os

     import torch
     import torch.distributed as dist
-    import torch.distributed.checkpoint as DCP
+    import torch.distributed.checkpoint as dcp
     import torch.multiprocessing as mp
     import torch.nn as nn

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.checkpoint.state_dict import get_state_dict
     from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

     CHECKPOINT_DIR = "checkpoint"
@@ -99,20 +104,14 @@ Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy inp
     model(torch.rand(8, 16, device="cuda")).sum().backward()
     optimizer.step()

-    # set FSDP StateDictType to SHARDED_STATE_DICT so we can use DCP to checkpoint sharded model state dict
-    # note that we do not support FSDP StateDictType.LOCAL_STATE_DICT
-    FSDP.set_state_dict_type(
-        model,
-        StateDictType.SHARDED_STATE_DICT,
-    )
+    # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
     state_dict = {
-        "model": model.state_dict(),
+        "model": model_state_dict,
+        "optimizer": optimizer_state_dict
     }
+    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

-    DCP.save_state_dict(
-        state_dict=state_dict,
-        storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
-    )

     cleanup()
@@ -152,12 +151,12 @@ The reason that we need the ``state_dict`` prior to loading is:

     import torch
     import torch.distributed as dist
-    import torch.distributed.checkpoint as DCP
+    import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

     CHECKPOINT_DIR = "checkpoint"
@@ -194,21 +193,23 @@ The reason that we need the ``state_dict`` prior to loading is:
     model = ToyModel().to(rank)
     model = FSDP(model)

-    FSDP.set_state_dict_type(
-        model,
-        StateDictType.SHARDED_STATE_DICT,
-    )
-    # different from ``torch.load()``, DCP requires model state_dict prior to loading to get
-    # the allocated storage and sharding information.
+    # generates the state dict we will load into
+    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
     state_dict = {
-        "model": model.state_dict(),
+        "model": model_state_dict,
+        "optimizer": optimizer_state_dict
     }
-
-    DCP.load_state_dict(
+    dcp.load(
         state_dict=state_dict,
-        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+        checkpoint_id=CHECKPOINT_DIR,
+    )
+    # sets our state dicts on the model and optimizer, now that we've loaded
+    set_state_dict(
+        model,
+        optimizer,
+        model_state_dict=model_state_dict,
+        optim_state_dict=optimizer_state_dict
     )
-    model.load_state_dict(state_dict["model"])

     cleanup()
@@ -224,7 +225,8 @@ The reason that we need the ``state_dict`` prior to loading is:
     )

 If you would like to load the saved checkpoint into a non-FSDP wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP.
-By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data(SPMD) style. To load without a distributed setup, please set ``no_dist`` to ``True`` when loading with DCP.
+By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers
+the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

 .. note::
    Distributed checkpoint support for Multi-Program Multi-Data is still under development.
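
The updated paragraph above notes that the same no-process-group inference applies to saving as well as loading; a minimal sketch of that case (not part of the recipe; the module and path names are arbitrary):

    import torch.distributed.checkpoint as dcp
    import torch.nn as nn

    model = nn.Linear(16, 16)
    # no dist.init_process_group() call, so DCP runs entirely in this process
    dcp.save({"model": model.state_dict()}, checkpoint_id="checkpoint")
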
@@ -259,11 +261,10 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M
         "model": model.state_dict(),
     }

-    # turn no_dist to be true to load in non-distributed setting
-    DCP.load_state_dict(
+    # since no process group is initialized, DCP will disable any collectives.
+    dcp.load(
         state_dict=state_dict,
-        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
-        no_dist=True,
+        checkpoint_id=CHECKPOINT_DIR,
     )
     model.load_state_dict(state_dict["model"])
@@ -274,7 +275,9 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M

 Conclusion
 ----------
-In conclusion, we have learned how to use DCP's :func:`save_state_dict` and :func:`load_state_dict` APIs, as well as how they are different form :func:`torch.save` and :func:`torch.load`.
+In conclusion, we have learned how to use DCP's :func:`save` and :func:`load` APIs, as well as how they are different from :func:`torch.save` and :func:`torch.load`.
+Additionally, we've learned how to use :func:`get_state_dict` and :func:`set_state_dict` to automatically manage parallelism-specific FQNs and defaults during state dict
+generation and loading.

 For more information, please see the following: