Commit 52c96e2

Author: Svetlana Karslioglu
Merge branch 'main' into add_amx_doc
2 parents 622542a + 6dc7f82 commit 52c96e2

8 files changed, +1877 -0 lines changed

Diff for: _static/img/half_cheetah.gif

2.26 MB

Diff for: advanced_source/cpp_cuda_graphs.rst

+193
@@ -0,0 +1,193 @@
Using CUDA Graphs in PyTorch C++ API
====================================

.. note::
   |edit| View and edit this tutorial in `GitHub <https://github.com/pytorch/tutorials/blob/main/advanced_source/cpp_cuda_graphs.rst>`__. The full source code is available on `GitHub <https://github.com/pytorch/tutorials/blob/main/advanced_source/cpp_cuda_graphs>`__.

Prerequisites:

- `Using the PyTorch C++ Frontend <../advanced_source/cpp_frontend.html>`__
- `CUDA semantics <https://pytorch.org/docs/master/notes/cuda.html>`__
- PyTorch 2.0 or later
- CUDA 11 or later

NVIDIA’s CUDA Graphs have been part of the CUDA Toolkit since the
release of `version 10 <https://developer.nvidia.com/blog/cuda-graphs/>`_.
They can greatly reduce CPU overhead, which improves application
performance.

In this tutorial, we focus on using CUDA Graphs with the `C++
frontend of PyTorch <https://pytorch.org/tutorials/advanced/cpp_frontend.html>`_.
The C++ frontend is mostly used in production and deployment applications, which
are an important part of PyTorch use cases. Since their `first appearance
<https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/>`_,
CUDA Graphs have won the hearts of users and developers by being both highly
performant and simple to use. In fact, ``torch.compile`` in PyTorch 2.0 can use
CUDA Graphs (for example, in its ``reduce-overhead`` mode) to speed up training and inference.

We demonstrate the use of CUDA Graphs on PyTorch’s `MNIST
example <https://github.com/pytorch/examples/tree/main/cpp/mnist>`_.
Using CUDA Graphs in LibTorch (the C++ frontend) is very similar to using its
`Python counterpart <https://pytorch.org/docs/main/notes/cuda.html#cuda-graphs>`_,
with some differences in syntax and functionality.

Getting Started
---------------

The main training loop consists of several steps, shown in the
following code snippet:

.. code-block:: cpp

   for (auto& batch : data_loader) {
     // Move the current batch to the target device.
     auto data = batch.data.to(device);
     auto targets = batch.target.to(device);
     optimizer.zero_grad();
     auto output = model.forward(data);
     auto loss = torch::nll_loss(output, targets);
     loss.backward();
     optimizer.step();
   }

The example above includes a forward pass, a backward pass, and weight updates.

In this tutorial, we apply a CUDA Graph to all of the compute steps through whole-network
graph capture. Before doing so, we need to slightly modify the source code: we
preallocate the tensors so that they can be reused in the main training loop. Here is an example
implementation:

.. code-block:: cpp

   torch::TensorOptions FloatCUDA =
       torch::TensorOptions(device).dtype(torch::kFloat);
   torch::TensorOptions LongCUDA =
       torch::TensorOptions(device).dtype(torch::kLong);

   // Preallocated tensors that are reused on every iteration.
   torch::Tensor data = torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA);
   torch::Tensor targets = torch::zeros({kTrainBatchSize}, LongCUDA);
   torch::Tensor output = torch::zeros({1}, FloatCUDA);
   torch::Tensor loss = torch::zeros({1}, FloatCUDA);

   for (auto& batch : data_loader) {
     // Copy each batch into the preallocated tensors in place.
     data.copy_(batch.data);
     targets.copy_(batch.target);
     training_step(model, optimizer, data, targets, output, loss);
   }

Here, ``training_step`` simply consists of the forward and backward passes with the corresponding optimizer calls:

.. code-block:: cpp

   void training_step(
       Net& model,
       torch::optim::Optimizer& optimizer,
       torch::Tensor& data,
       torch::Tensor& targets,
       torch::Tensor& output,
       torch::Tensor& loss) {
     optimizer.zero_grad();
     output = model.forward(data);
     loss = torch::nll_loss(output, targets);
     loss.backward();
     optimizer.step();
   }

PyTorch’s CUDA Graphs API relies on Stream Capture, which in our case is used like this:

.. code-block:: cpp

   at::cuda::CUDAGraph graph;
   at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool();
   at::cuda::setCurrentCUDAStream(captureStream);

   // Everything launched between capture_begin() and capture_end() is captured.
   graph.capture_begin();
   training_step(model, optimizer, data, targets, output, loss);
   graph.capture_end();

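One way to see why the tensors are preallocated and refilled in place: a captured graph
records work against fixed memory addresses, so a replay only sees new inputs if they are
written into the same buffers. The following toy sketch illustrates this idea in plain C++
without CUDA. ``ToyGraph`` and ``capture_double`` are hypothetical names invented for this
illustration; they are not part of LibTorch.

.. code-block:: cpp

   #include <cstddef>
   #include <functional>
   #include <vector>

   // Toy stand-in for a captured graph (hypothetical, not a LibTorch type):
   // it stores operations bound to fixed buffers and re-runs them on replay().
   struct ToyGraph {
     std::vector<std::function<void()>> ops;

     void replay() const {
       for (const auto& op : ops) {
         op();
       }
     }
   };

   // "Captures" out = 2 * in against the two buffers passed in. The recorded
   // operation keeps pointing at these same buffers on every replay.
   inline ToyGraph capture_double(const std::vector<float>* in,
                                  std::vector<float>* out) {
     ToyGraph graph;
     graph.ops.push_back([in, out] {
       for (std::size_t i = 0; i < in->size() && i < out->size(); ++i) {
         (*out)[i] = 2.0f * (*in)[i];
       }
     });
     return graph;
   }

Overwriting the elements of the input buffer and calling ``replay()`` again yields results
for the new values. This mirrors filling ``data`` with ``data.copy_(batch.data)`` before each
replay of the real graph, rather than creating a new tensor for every batch.
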
Before the actual graph capture, it is important to run several warm-up iterations on a side stream to
prepare the CUDA cache as well as the CUDA libraries (like cuBLAS and cuDNN) that will be used during
training:

.. code-block:: cpp

   at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool();
   at::cuda::setCurrentCUDAStream(warmupStream);
   for (int iter = 0; iter < num_warmup_iters; iter++) {
     training_step(model, optimizer, data, targets, output, loss);
   }

After a successful graph capture, we can replace the ``training_step(model, optimizer, data, targets, output, loss);``
call with ``graph.replay();`` to perform the training step.

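Putting the pieces in order: the warm-up iterations run first, then the capture, and only
then the replays. The sketch below expresses that ordering as a small template over any type
that provides ``capture_begin()``, ``capture_end()``, and ``replay()``, which are the three
``at::cuda::CUDAGraph`` members used in this tutorial. The names ``warmup_capture_replay`` and
``RecordingGraph`` are assumptions made for this illustration, and stream switching is
deliberately left out.

.. code-block:: cpp

   #include <string>

   // Hypothetical stand-in that only logs which graph methods were called.
   // It lets the ordering be checked on a machine without a GPU.
   struct RecordingGraph {
     std::string log;
     void capture_begin() { log += 'B'; }
     void capture_end() { log += 'E'; }
     void replay() { log += 'R'; }
   };

   // Runs `step` eagerly for warm-up, captures it once, then replays it.
   template <typename Graph, typename Step>
   void warmup_capture_replay(Graph& graph, Step&& step,
                              int num_warmup_iters, int num_replays) {
     // 1. Warm-up: ordinary, non-captured executions of the step.
     for (int iter = 0; iter < num_warmup_iters; ++iter) {
       step();
     }
     // 2. Capture: with at::cuda::CUDAGraph, GPU work launched by step()
     //    here is recorded into the graph instead of being run.
     graph.capture_begin();
     step();
     graph.capture_end();
     // 3. Replay: each call stands in for one call to step(). In a real
     //    training loop, the next batch is copied in before each replay.
     for (int iter = 0; iter < num_replays; ++iter) {
       graph.replay();
     }
   }

With LibTorch, ``graph`` would be the ``at::cuda::CUDAGraph`` object and ``step`` a lambda that
calls ``training_step(model, optimizer, data, targets, output, loss)``, with the current stream
set to a side stream as shown earlier.
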
Training Results
----------------

Taking the code for a spin, we can see the following output from ordinary non-graphed training:

.. code-block:: shell

   $ time ./mnist
   Train Epoch: 1 [59584/60000] Loss: 0.3921
   Test set: Average loss: 0.2051 | Accuracy: 0.938
   Train Epoch: 2 [59584/60000] Loss: 0.1826
   Test set: Average loss: 0.1273 | Accuracy: 0.960
   Train Epoch: 3 [59584/60000] Loss: 0.1796
   Test set: Average loss: 0.1012 | Accuracy: 0.968
   Train Epoch: 4 [59584/60000] Loss: 0.1603
   Test set: Average loss: 0.0869 | Accuracy: 0.973
   Train Epoch: 5 [59584/60000] Loss: 0.2315
   Test set: Average loss: 0.0736 | Accuracy: 0.978
   Train Epoch: 6 [59584/60000] Loss: 0.0511
   Test set: Average loss: 0.0704 | Accuracy: 0.977
   Train Epoch: 7 [59584/60000] Loss: 0.0802
   Test set: Average loss: 0.0654 | Accuracy: 0.979
   Train Epoch: 8 [59584/60000] Loss: 0.0774
   Test set: Average loss: 0.0604 | Accuracy: 0.980
   Train Epoch: 9 [59584/60000] Loss: 0.0669
   Test set: Average loss: 0.0544 | Accuracy: 0.984
   Train Epoch: 10 [59584/60000] Loss: 0.0219
   Test set: Average loss: 0.0517 | Accuracy: 0.983

   real    0m44.287s
   user    0m44.018s
   sys     0m1.116s

Training with the CUDA Graph produces the following output:

.. code-block:: shell

   $ time ./mnist --use-train-graph
   Train Epoch: 1 [59584/60000] Loss: 0.4092
   Test set: Average loss: 0.2037 | Accuracy: 0.938
   Train Epoch: 2 [59584/60000] Loss: 0.2039
   Test set: Average loss: 0.1274 | Accuracy: 0.961
   Train Epoch: 3 [59584/60000] Loss: 0.1779
   Test set: Average loss: 0.1017 | Accuracy: 0.968
   Train Epoch: 4 [59584/60000] Loss: 0.1559
   Test set: Average loss: 0.0871 | Accuracy: 0.972
   Train Epoch: 5 [59584/60000] Loss: 0.2240
   Test set: Average loss: 0.0735 | Accuracy: 0.977
   Train Epoch: 6 [59584/60000] Loss: 0.0520
   Test set: Average loss: 0.0710 | Accuracy: 0.978
   Train Epoch: 7 [59584/60000] Loss: 0.0935
   Test set: Average loss: 0.0666 | Accuracy: 0.979
   Train Epoch: 8 [59584/60000] Loss: 0.0744
   Test set: Average loss: 0.0603 | Accuracy: 0.981
   Train Epoch: 9 [59584/60000] Loss: 0.0762
   Test set: Average loss: 0.0547 | Accuracy: 0.983
   Train Epoch: 10 [59584/60000] Loss: 0.0207
   Test set: Average loss: 0.0525 | Accuracy: 0.983

   real    0m6.952s
   user    0m7.048s
   sys     0m0.619s

Conclusion
----------

As we can see, just by applying a CUDA Graph to the `MNIST example
<https://github.com/pytorch/examples/tree/main/cpp/mnist>`_ we were able to speed up
training by more than six times. Such a large improvement was achievable because
the model is small. For larger models with heavy GPU usage, the CPU overhead is less impactful,
so the improvement will be smaller. Nevertheless, it is generally advantageous to use CUDA Graphs to
get the most performance out of GPUs.
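
As a quick check of the "more than six times" figure, the speedup is the ratio of the two
``real`` wall-clock times reported by ``time`` in the outputs above. The helper below is a
minimal sketch; the name ``speedup`` is an assumption made for this illustration.

.. code-block:: cpp

   #include <stdexcept>

   // Speedup factor: baseline wall-clock time divided by optimized wall-clock
   // time. Both arguments are in seconds and must be positive.
   inline double speedup(double baseline_seconds, double optimized_seconds) {
     if (baseline_seconds <= 0.0 || optimized_seconds <= 0.0) {
       throw std::invalid_argument("times must be positive");
     }
     return baseline_seconds / optimized_seconds;
   }

With the measurements above, ``speedup(44.287, 6.952)`` is approximately 6.37.
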

Diff for: advanced_source/cpp_cuda_graphs/CMakeLists.txt

+31
@@ -0,0 +1,31 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(mnist)
set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)
find_package(Threads REQUIRED)

option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)
  message(STATUS "Downloading MNIST dataset")
  execute_process(
    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
      -d ${CMAKE_BINARY_DIR}/data
    ERROR_VARIABLE DOWNLOAD_ERROR)
  if (DOWNLOAD_ERROR)
    message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
  endif()
endif()

add_executable(mnist mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})

if (MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET mnist
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${TORCH_DLLS}
                     $<TARGET_FILE_DIR:mnist>)
endif (MSVC)

Diff for: advanced_source/cpp_cuda_graphs/README.md

+38
@@ -0,0 +1,38 @@
# MNIST Example with the PyTorch C++ Frontend

This folder contains an example of training a computer vision model to recognize
digits in images from the MNIST dataset, using the PyTorch C++ frontend.

The entire training code is contained in `mnist.cpp`.

To build the code, run the following commands from your terminal:

```shell
$ cd mnist
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` should be the path to the unzipped _LibTorch_
distribution, which you can get from the [PyTorch
homepage](https://pytorch.org/get-started/locally/).

Execute the compiled binary to train the model:

```shell
$ ./mnist
Train Epoch: 1 [59584/60000] Loss: 0.4232
Test set: Average loss: 0.1989 | Accuracy: 0.940
Train Epoch: 2 [59584/60000] Loss: 0.1926
Test set: Average loss: 0.1338 | Accuracy: 0.959
Train Epoch: 3 [59584/60000] Loss: 0.1390
Test set: Average loss: 0.0997 | Accuracy: 0.969
Train Epoch: 4 [59584/60000] Loss: 0.1239
Test set: Average loss: 0.0875 | Accuracy: 0.972
...
```

To run with CUDA Graphs, add `--use-train-graph` and/or `--use-test-graph`
for the training and testing passes, respectively.
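
How the two flags are parsed is not shown in this README. A minimal sketch of one way to
handle them, assuming plain `argv`-style arguments, is given below. `GraphFlags` and
`parse_graph_flags` are hypothetical names and may not match what `mnist.cpp` actually does.

```cpp
#include <string>
#include <vector>

// Which passes should run through a CUDA Graph (hypothetical helper type).
struct GraphFlags {
  bool use_train_graph = false;
  bool use_test_graph = false;
};

// Scans the arguments for the two flags named in this README.
// Unrecognized arguments are ignored in this sketch.
inline GraphFlags parse_graph_flags(const std::vector<std::string>& args) {
  GraphFlags flags;
  for (const std::string& arg : args) {
    if (arg == "--use-train-graph") {
      flags.use_train_graph = true;
    } else if (arg == "--use-test-graph") {
      flags.use_test_graph = true;
    }
  }
  return flags;
}
```

In `main`, the arguments could be collected with
`std::vector<std::string>(argv + 1, argv + argc)` and passed to this function.
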
