Commit 2585b77

Merge pull request #1116 from tanayvarshney/docs
Added Triton deployment instructions to documentation
2 parents cf8da43 + 2c01adc commit 2585b77

3 files changed: +234 -18 lines changed

docsrc/index.rst (+2)

@@ -28,6 +28,7 @@ Getting Started
 * :ref:`use_from_pytorch`
 * :ref:`runtime`
 * :ref:`using_dla`
+* :ref:`deploy_torch_tensorrt_to_triton`

 .. toctree::
    :caption: Getting Started
@@ -43,6 +44,7 @@ Getting Started
 tutorials/use_from_pytorch
 tutorials/runtime
 tutorials/using_dla
+tutorials/deploy_torch_tensorrt_to_triton

 .. toctree::
    :caption: Notebooks
docsrc/tutorials/deploy_torch_tensorrt_to_triton.rst (new file, +214)

Deploying a Torch-TensorRT model (to Triton)
============================================

Optimization and deployment go hand in hand in a discussion about Machine
Learning infrastructure. Once network-level optimizations are done
to get the maximum performance, the next step would be to deploy it.

However, serving this optimized model comes with its own set of considerations
and challenges, such as building an infrastructure to support concurrent model
executions, supporting clients over HTTP or gRPC, and more.

The `Triton Inference Server <https://github.com/triton-inference-server/server>`__
solves these challenges and more. Let's discuss, step by step, the process of
optimizing a model with Torch-TensorRT, deploying it on Triton Inference
Server, and building a client to query the model.

Step 1: Optimize your model with Torch-TensorRT
-----------------------------------------------

Most Torch-TensorRT users will be familiar with this step. For the purpose of
this demonstration, we will be using a ResNet50 model from Torch Hub.

Let’s first pull the `NGC PyTorch Docker container <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`__. You may need to create
an account and get the API key from `here <https://ngc.nvidia.com/setup/>`__.
Sign up and log in with your key (follow the instructions
`here <https://ngc.nvidia.com/setup/api-key>`__ after signing up).

::

   # <xx.xx> is the yy.mm publishing tag for NVIDIA's PyTorch
   # container; e.g. 22.04

   docker run -it --gpus all -v ${PWD}:/scratch_space nvcr.io/nvidia/pytorch:<xx.xx>-py3
   cd /scratch_space

Once inside the container, we can proceed to download a ResNet model from
Torch Hub and optimize it with Torch-TensorRT.

::

   import torch
   import torch_tensorrt

   torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

   # Load the model
   model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")

   # Compile with Torch-TensorRT
   trt_model = torch_tensorrt.compile(model,
       inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
       enabled_precisions={torch.half}  # Run with FP16
   )

   # Save the model
   torch.jit.save(trt_model, "model.pt")

After saving the model, exit the container. Because ``/scratch_space`` is
mounted from the host, ``model.pt`` is now available outside the container as
well. The next step in the process is to set up a Triton Inference Server.

Step 2: Set Up Triton Inference Server
--------------------------------------

If you are new to the Triton Inference Server and want to learn more, we
highly recommend checking out our `GitHub
Repository <https://github.com/triton-inference-server>`__.

To use Triton, we need to make a model repository. A model repository, as the
name suggests, is a repository of the models the Inference Server hosts. While
Triton can serve models from multiple repositories, in this example, we will
discuss the simplest possible form of the model repository.

The structure of this repository should look something like this:

::

   model_repository
   |
   +-- resnet50
       |
       +-- config.pbtxt
       +-- 1
           |
           +-- model.pt

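The layout above can be created with a couple of shell commands; this is a
minimal sketch that assumes ``model.pt`` from Step 1 sits in your current
working directory on the host:

::

   # Create the versioned directory Triton expects and copy the model into it
   mkdir -p model_repository/resnet50/1
   cp model.pt model_repository/resnet50/1/model.pt

The ``config.pbtxt`` described next is placed directly under
``model_repository/resnet50/``.
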
There are two files that Triton requires to serve the model: the model itself
and a model configuration file, which is typically provided in ``config.pbtxt``.
For the model we prepared in step 1, the following configuration can be used:

::

   name: "resnet50"
   platform: "pytorch_libtorch"
   max_batch_size : 0
   input [
     {
       name: "input__0"
       data_type: TYPE_FP32
       dims: [ 3, 224, 224 ]
       reshape { shape: [ 1, 3, 224, 224 ] }
     }
   ]
   output [
     {
       name: "output__0"
       data_type: TYPE_FP32
       dims: [ 1, 1000, 1, 1 ]
       reshape { shape: [ 1, 1000 ] }
     }
   ]

The ``config.pbtxt`` file is used to describe the exact model configuration,
with details like the names and shapes of the input and output layer(s),
datatypes, scheduling and batching details, and more. If you are new to Triton,
we highly encourage you to check out this `section of our
documentation <https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md>`__
for more details.

With the model repository set up, we can proceed to launch the Triton server
with the docker command below. Refer to `this page <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver>`__ for the pull tag for the container.

::

   # Make sure that the TensorRT version in the Triton container
   # and the TensorRT version in the environment used to optimize the model
   # are the same.

   docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /full/path/to/the_model_repository/model_repository:/models nvcr.io/nvidia/tritonserver:<xx.yy>-py3 tritonserver --model-repository=/models

This should spin up a Triton Inference Server. The next step is to build a
simple HTTP client to query the server.
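
Before writing the client, a quick sanity check can save some debugging time.
Assuming the default HTTP port 8000 mapped above, Triton's readiness endpoint
returns ``200`` once the server and its models are up:

::

   curl -v localhost:8000/v2/health/ready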

Step 3: Building a Triton Client to Query the Server
----------------------------------------------------

Before proceeding, make sure to have a sample image on hand. If you don't
have one, download an example image to test inference. In this section, we
will be going over a very basic client. For a variety of more fleshed-out
examples, refer to the `Triton Client Repository <https://github.com/triton-inference-server/client/tree/main/src/python/examples>`__.

::

   wget -O img1.jpg "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg"

We then need to install dependencies for building a Python client. These will
change from client to client. For a full list of all languages supported by Triton,
please refer to `Triton's client repository <https://github.com/triton-inference-server/client>`__.

::

   pip install torchvision
   pip install attrdict
   pip install nvidia-pyindex
   pip install tritonclient[all]

Let's jump into the client. First, we write a small preprocessing function to
resize and normalize the query image.

::

   import numpy as np
   from torchvision import transforms
   from PIL import Image
   import tritonclient.http as httpclient
   from tritonclient.utils import triton_to_np_dtype

   # preprocessing function
   def rn50_preprocess(img_path="img1.jpg"):
       img = Image.open(img_path)
       preprocess = transforms.Compose([
           transforms.Resize(256),
           transforms.CenterCrop(224),
           transforms.ToTensor(),
           transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
       ])
       return preprocess(img).numpy()

   transformed_img = rn50_preprocess()

Building a client requires three basic steps. First, we set up a connection
with the Triton Inference Server.

::

   # Setting up client
   client = httpclient.InferenceServerClient(url="localhost:8000")

Second, we specify the names of the input and output layer(s) of our model.

::

   inputs = httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")
   inputs.set_data_from_numpy(transformed_img, binary_data=True)

   outputs = httpclient.InferRequestedOutput("output__0", binary_data=True, class_count=1000)

Lastly, we send an inference request to the Triton Inference Server.

::

   # Querying the server
   results = client.infer(model_name="resnet50", inputs=[inputs], outputs=[outputs])
   inference_output = results.as_numpy('output__0')
   print(inference_output[:5])

The output should look like the following:

::

   [b'12.468750:90' b'11.523438:92' b'9.664062:14' b'8.429688:136'
    b'8.234375:11']

The output format here is ``<confidence_score>:<classification_index>``.
To learn how to map these to the label names and more, refer to our
`documentation <https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_classification.md>`__.
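
As a quick illustration of that mapping (not part of the tutorial's own steps),
the sketch below parses the ``<confidence_score>:<classification_index>``
strings and looks up the ImageNet class names; it assumes a torchvision version
recent enough to expose the ``ResNet50_Weights`` enum and its ``meta`` dictionary:

::

   # Illustrative only: resolve classification indices to ImageNet labels.
   from torchvision.models import ResNet50_Weights

   categories = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
   for entry in inference_output[:5]:
       score, idx = entry.decode("utf-8").split(":")
       print(f"{categories[int(idx)]}: {float(score):.4f}")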

notebooks/dynamic-shapes.ipynb (+18 -18)
@@ -7,7 +7,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Copyright 2020 NVIDIA Corporation. All Rights Reserved.\n",
+    "# Copyright 2022 NVIDIA Corporation. All Rights Reserved.\n",
     "#\n",
     "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
     "# you may not use this file except in compliance with the License.\n",
@@ -36,14 +36,14 @@
    "id": "73703695",
    "metadata": {},
    "source": [
-    "Torch-TensorRT is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, Torch-TensorRT is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into an module targeting a TensorRT engine. Torch-TensorRT operates as a PyTorch extention and compiles modules that integrate into the JIT runtime seamlessly. After compilation using the optimized graph should feel no different than running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/FP16/INT8) and other settings for your module.\n",
+    "Torch-TensorRT is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, Torch-TensorRT is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into a module targeting a TensorRT engine. Torch-TensorRT operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation, using the optimized graph should feel no different than running a TorchScript module. You also have access to TensorRT's suite of configurations at compile-time, so you are able to specify operating precision (FP32/FP16/INT8) and other settings for your module.\n",
     "\n",
-    "We highly encorage users to use our NVIDIA's [PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) to run this notebook. It comes packaged with a host of NVIDIA libraries and optimizations to widely used third party libraries. This container is tested and updated on a monthly cadence!\n",
+    "We highly encourage users to run this notebook using our NVIDIA's [PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). It comes packaged with a host of NVIDIA libraries and optimizations to widely used third-party libraries. In addition, this container is tested and updated on a monthly cadence!\n",
     "\n",
     "This notebook has the following sections:\n",
-    "1. [TL;DR Explanation](#1)\n",
-    "1. [Setting up the model](#2)\n",
-    "1. [Working with Dynamic shapes in Torch TRT](#3)"
+    "1. TL;DR Explanation\n",
+    "1. Setting up the model\n",
+    "1. Working with Dynamic shapes in Torch TRT]"
    ]
   },
  {
@@ -633,7 +633,7 @@
    "id": "21402d53",
    "metadata": {},
    "source": [
-    "Let's test our util functions on the model we have set up, starting with simple predictions"
+    "Let's test our util functions on the model we have set up, starting with simple predictions."
    ]
   },
  {
@@ -820,19 +820,19 @@
    "source": [
     "---\n",
     "## Working with Dynamic shapes in Torch TRT\n",
-    "\n",
-    "Enabling \"Dynamic Shaped\" tensors to be used is essentially enabling the ability to defer defining the shape of tensors until runetime. Torch TensorRT simply leverages TensorRT's Dynamic shape support. You can read more about TensorRT's implementation in the [TensorRT Documentation](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work_dynamic_shapes).\n",
-    "\n",
+    " \n",
+    "Enabling \"Dynamic Shaped\" tensors to be used is essentially enabling the ability to defer defining the shape of tensors until run-time. Torch TensorRT simply leverages TensorRT's Dynamic shape support. You can read more about TensorRT's implementation in the [TensorRT Documentation](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work_dynamic_shapes).\n",
+    " \n",
     "#### How can you use this feature?\n",
-    "\n",
+    " \n",
     "To make use of dynamic shapes, you need to provide three shapes:\n",
     "* `min_shape`: The minimum size of the tensor considered for optimizations.\n",
-    "* `opt_shape`: The optimizations will be done with an effort to maximize performance for this shape.\n",
-    "* `min_shape`: The maximum size of the tensor considered for optimizations.\n",
-    "\n",
-    "Generally, users can expect best performance within the specified ranges. Performance for other shapes may be be lower for other shapes (depending on the model ops and GPU used)\n",
-    "\n",
-    "In the following example, we will showcase varing batch size, which is the zeroth dimension of our input tensors. As Convolution operations require that the channel dimension be a build-time constant, we won't be changing sizes of other channels in this example, but for models which contain ops conducive to changes in other channels, this functionality can be freely used."
+    "* `opt_shape`: The optimizations will be done in an effort to maximize performance for this shape.\n",
+    "* `max_shape`: The maximum size of the tensor considered for optimizations.\n",
+    " \n",
+    "Generally, users can expect the best performance within the specified ranges. Performance for other shapes maybe be lower for other shapes (depending on the model ops and GPU used)\n",
+    " \n",
+    "In the following example, we will showcase varying batch sizes, which is the zeroth dimension of our input tensors. As Convolution operations require that the channel dimension be a build-time constant, we won't be changing the sizes of other channels in this example, but for models which contain ops conducive to changes in other channels, this functionality can be freely used."
    ]
   },
  {
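
(The hunk above only rewords prose. For readers skimming this diff, here is a
minimal, illustrative sketch of how these three shapes are passed to
Torch-TensorRT; it is not part of the commit, and the concrete shapes and the
``model`` variable are assumptions for a ResNet-style network.)

::

   import torch
   import torch_tensorrt

   # Compile with a dynamic batch (zeroth) dimension.
   trt_model = torch_tensorrt.compile(
       model,
       inputs=[
           torch_tensorrt.Input(
               min_shape=(1, 3, 224, 224),   # smallest batch considered
               opt_shape=(8, 3, 224, 224),   # shape the engine is tuned for
               max_shape=(16, 3, 224, 224),  # largest batch considered
               dtype=torch.float32,
           )
       ],
       enabled_precisions={torch.half},
   )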
@@ -1015,7 +1015,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.13"
+   "version": "3.9.6"
   }
  },
  "nbformat": 4,
