Using CUDA Graphs in PyTorch C++ API
====================================

.. note::
   |edit| View and edit this tutorial in `GitHub <https://github.com/pytorch/tutorials/blob/main/advanced_source/cpp_cuda_graphs.rst>`__. The full source code is available on `GitHub <https://github.com/pytorch/tutorials/blob/main/advanced_source/cpp_cuda_graphs>`__.

Prerequisites:

- `Using the PyTorch C++ Frontend <../advanced_source/cpp_frontend.html>`__
- `CUDA semantics <https://pytorch.org/docs/master/notes/cuda.html>`__
- PyTorch 2.0 or later
- CUDA 11 or later

NVIDIA’s CUDA Graphs have been a part of the CUDA Toolkit since the
release of `version 10 <https://developer.nvidia.com/blog/cuda-graphs/>`_.
They can greatly reduce CPU overhead and thereby increase the
performance of applications.

In this tutorial, we focus on using CUDA Graphs with the `C++
frontend of PyTorch <https://pytorch.org/tutorials/advanced/cpp_frontend.html>`_.
The C++ frontend is mostly used in production and deployment applications,
which are an important part of PyTorch use cases. Since `their first appearance
<https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/>`_,
CUDA Graphs have won users’ and developers’ hearts for being a very performant
and at the same time simple-to-use tool. In fact, CUDA Graphs are used by default
in ``torch.compile`` of PyTorch 2.0 to boost the performance of training and inference.

We demonstrate CUDA Graphs usage on PyTorch’s `MNIST
example <https://github.com/pytorch/examples/tree/main/cpp/mnist>`_.
The usage of CUDA Graphs in LibTorch (the C++ frontend) is very similar to its
`Python counterpart <https://pytorch.org/docs/main/notes/cuda.html#cuda-graphs>`_,
but with some differences in syntax and functionality.

Getting Started
---------------

The main training loop consists of several steps, as depicted in the
following code chunk:

.. code-block:: cpp

  for (auto& batch : data_loader) {
    auto data = batch.data.to(device);
    auto targets = batch.target.to(device);
    optimizer.zero_grad();
    auto output = model.forward(data);
    auto loss = torch::nll_loss(output, targets);
    loss.backward();
    optimizer.step();
  }

The example above includes a forward pass, a backward pass, and weight updates.
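
For context, the surrounding setup in the MNIST example looks roughly like the sketch
below. The names ``Net`` and the optimizer settings are taken from (or assumed to
resemble) the example code, so the full source may differ slightly:

.. code-block:: cpp

  // Run on the GPU; CUDA Graphs require a CUDA device.
  torch::Device device(torch::kCUDA);

  // The MNIST network and optimizer from the example.
  Net model;
  model.to(device);
  torch::optim::SGD optimizer(
      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));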

In this tutorial, we will be applying CUDA Graphs to all the compute steps via whole-network
graph capture. Before doing so, we need to slightly modify the source code by
preallocating tensors and reusing them in the main training loop. Here is an example
implementation:

.. code-block:: cpp

  torch::TensorOptions FloatCUDA =
      torch::TensorOptions(device).dtype(torch::kFloat);
  torch::TensorOptions LongCUDA =
      torch::TensorOptions(device).dtype(torch::kLong);

  torch::Tensor data = torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA);
  torch::Tensor targets = torch::zeros({kTrainBatchSize}, LongCUDA);
  torch::Tensor output = torch::zeros({1}, FloatCUDA);
  torch::Tensor loss = torch::zeros({1}, FloatCUDA);

  for (auto& batch : data_loader) {
    data.copy_(batch.data);
    targets.copy_(batch.target);
    training_step(model, optimizer, data, targets, output, loss);
  }

Here, ``training_step`` simply consists of forward and backward passes with the corresponding optimizer calls:

.. code-block:: cpp

  void training_step(
      Net& model,
      torch::optim::Optimizer& optimizer,
      torch::Tensor& data,
      torch::Tensor& targets,
      torch::Tensor& output,
      torch::Tensor& loss) {
    optimizer.zero_grad();
    output = model.forward(data);
    loss = torch::nll_loss(output, targets);
    loss.backward();
    optimizer.step();
  }

PyTorch’s CUDA Graphs API relies on stream capture, which in our case is used like this:

.. code-block:: cpp

  at::cuda::CUDAGraph graph;
  at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool();
  at::cuda::setCurrentCUDAStream(captureStream);

  graph.capture_begin();
  training_step(model, optimizer, data, targets, output, loss);
  graph.capture_end();

Before the actual graph capture, it is important to run several warm-up iterations on a side stream to
prepare the CUDA cache as well as the CUDA libraries (such as cuBLAS and cuDNN) that will be used during
training:

.. code-block:: cpp

  at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool();
  at::cuda::setCurrentCUDAStream(warmupStream);
  for (int iter = 0; iter < num_warmup_iters; iter++) {
    training_step(model, optimizer, data, targets, output, loss);
  }

After a successful graph capture, we can replace the ``training_step(model, optimizer, data, targets, output, loss);``
call with ``graph.replay();`` to perform the training step.

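Putting the pieces together, the graphed training loop looks roughly like the sketch
below. It only combines the snippets shown above, plus a device-wide synchronization
(``torch::cuda::synchronize()``) so the warm-up work finishes before capture starts, and
assumes the same tensors, model, optimizer, and ``training_step`` helper; see the full
source linked at the top for the complete version:

.. code-block:: cpp

  // 1. Warm up on a side stream so CUDA caches and libraries are initialized.
  at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool();
  at::cuda::setCurrentCUDAStream(warmupStream);
  for (int iter = 0; iter < num_warmup_iters; iter++) {
    training_step(model, optimizer, data, targets, output, loss);
  }
  torch::cuda::synchronize();

  // 2. Capture one training step into the graph on a capture stream.
  at::cuda::CUDAGraph graph;
  at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool();
  at::cuda::setCurrentCUDAStream(captureStream);
  graph.capture_begin();
  training_step(model, optimizer, data, targets, output, loss);
  graph.capture_end();

  // 3. Train by copying each batch into the static tensors and replaying the graph.
  for (auto& batch : data_loader) {
    data.copy_(batch.data);
    targets.copy_(batch.target);
    graph.replay();
  }
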
Training Results
----------------

Taking the code for a spin, we can see the following output from ordinary, non-graphed training:

.. code-block:: shell

  $ time ./mnist
  Train Epoch: 1 [59584/60000] Loss: 0.3921
  Test set: Average loss: 0.2051 | Accuracy: 0.938
  Train Epoch: 2 [59584/60000] Loss: 0.1826
  Test set: Average loss: 0.1273 | Accuracy: 0.960
  Train Epoch: 3 [59584/60000] Loss: 0.1796
  Test set: Average loss: 0.1012 | Accuracy: 0.968
  Train Epoch: 4 [59584/60000] Loss: 0.1603
  Test set: Average loss: 0.0869 | Accuracy: 0.973
  Train Epoch: 5 [59584/60000] Loss: 0.2315
  Test set: Average loss: 0.0736 | Accuracy: 0.978
  Train Epoch: 6 [59584/60000] Loss: 0.0511
  Test set: Average loss: 0.0704 | Accuracy: 0.977
  Train Epoch: 7 [59584/60000] Loss: 0.0802
  Test set: Average loss: 0.0654 | Accuracy: 0.979
  Train Epoch: 8 [59584/60000] Loss: 0.0774
  Test set: Average loss: 0.0604 | Accuracy: 0.980
  Train Epoch: 9 [59584/60000] Loss: 0.0669
  Test set: Average loss: 0.0544 | Accuracy: 0.984
  Train Epoch: 10 [59584/60000] Loss: 0.0219
  Test set: Average loss: 0.0517 | Accuracy: 0.983

  real 0m44.287s
  user 0m44.018s
  sys 0m1.116s

Training with the CUDA Graph, on the other hand, produces the following output:

.. code-block:: shell

  $ time ./mnist --use-train-graph
  Train Epoch: 1 [59584/60000] Loss: 0.4092
  Test set: Average loss: 0.2037 | Accuracy: 0.938
  Train Epoch: 2 [59584/60000] Loss: 0.2039
  Test set: Average loss: 0.1274 | Accuracy: 0.961
  Train Epoch: 3 [59584/60000] Loss: 0.1779
  Test set: Average loss: 0.1017 | Accuracy: 0.968
  Train Epoch: 4 [59584/60000] Loss: 0.1559
  Test set: Average loss: 0.0871 | Accuracy: 0.972
  Train Epoch: 5 [59584/60000] Loss: 0.2240
  Test set: Average loss: 0.0735 | Accuracy: 0.977
  Train Epoch: 6 [59584/60000] Loss: 0.0520
  Test set: Average loss: 0.0710 | Accuracy: 0.978
  Train Epoch: 7 [59584/60000] Loss: 0.0935
  Test set: Average loss: 0.0666 | Accuracy: 0.979
  Train Epoch: 8 [59584/60000] Loss: 0.0744
  Test set: Average loss: 0.0603 | Accuracy: 0.981
  Train Epoch: 9 [59584/60000] Loss: 0.0762
  Test set: Average loss: 0.0547 | Accuracy: 0.983
  Train Epoch: 10 [59584/60000] Loss: 0.0207
  Test set: Average loss: 0.0525 | Accuracy: 0.983

  real 0m6.952s
  user 0m7.048s
  sys 0m0.619s

Conclusion
----------

As we can see, just by applying a CUDA Graph to the `MNIST example
<https://github.com/pytorch/examples/tree/main/cpp/mnist>`_ we were able to improve training
performance by more than six times (roughly 44.3 s versus 7.0 s of wall-clock time). Such a large
improvement was possible because of the small model size: for larger models with heavy GPU usage,
the CPU overhead is less impactful, so the improvement will be smaller. Nevertheless, it is always
advantageous to use CUDA Graphs to get the most performance out of GPUs.