Commit f2d1655
Author: Anurag Dixit
Commit message: feat: Perf benchmark initial draft
Signed-off-by: Anurag Dixit <[email protected]>
Parent: 5eee16f

File tree: 2 files changed, +309 -0 lines

Diff for: examples/benchmark/py/config/vgg16.yml (+19)

@@ -0,0 +1,19 @@
---
backend:
  - torch
  - torch_tensorrt
input:
  input0:
    - 1
    - 3
    - 224
    - 224
  num_of_input: 1
model:
  filename: vgg16_traced.jit.pt
  name: vgg16
runtime:
  device: 0
  precision:
    - fp32
    - fp16
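
Usage note (a sketch, not part of the commit): perf_run.py below resolves model files against a relative models/ directory, so with the traced model placed there the benchmark would presumably be launched from examples/benchmark/py as

    python perf_run.py --config config/vgg16.yml

Optional top-level keys iterations and batch_size in the YAML override the script's defaults of 20 timed iterations and batch size 1.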

Diff for: examples/benchmark/py/perf_run.py (+290)

@@ -0,0 +1,290 @@
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import argparse
import timeit
import numpy as np
import torch.backends.cudnn as cudnn
import yaml
import os
import pandas as pd

# Backends
import torch
import torch_tensorrt as torchtrt
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda


TRT_LOGGER = trt.Logger()
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

# Accumulates one row of stats per (backend, precision) run
results = []

def run_torch(model, input_tensors, params, precision):
    print("Running Torch for precision: ", precision)

    iters = params.get('iterations', 20)

    # Warm up
    with torch.no_grad():
        for _ in range(20):
            features = model(*input_tensors)

    torch.cuda.synchronize()

    timings = []
    with torch.no_grad():
        for i in range(iters):
            start_time = timeit.default_timer()
            features = model(*input_tensors)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            meas_time = end_time - start_time
            timings.append(meas_time)
            print("Iteration {}: {:.6f} s".format(i, end_time - start_time))

    printStats("Torch", timings, precision)

def onnx_to_trt_engine(onnx_model, precision):

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser, trt.Runtime(TRT_LOGGER) as runtime:
        config.max_workspace_size = 1 << 28  # 256MiB
        builder.max_batch_size = 1

        # Parse the serialized ONNX bytes into the TensorRT network
        if not parser.parse(onnx_model):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            return None

        if precision == 'int8':
            config.set_flag(trt.BuilderFlag.INT8)
        elif precision == 'fp16' or precision == 'half':
            config.set_flag(trt.BuilderFlag.HALF)

        plan = builder.build_serialized_network(network, config)
        engine = runtime.deserialize_cuda_engine(plan)
        return engine

def run_torch_tensorrt(model, input_tensors, params, precision):
    print("Running Torch-TensorRT")

    # Compile the module with Torch-TensorRT for the requested precision
    compile_settings = {
        "inputs": input_tensors,
        "enabled_precisions": {precision_to_dtype(precision)}
    }

    model = torchtrt.compile(model, **compile_settings)

    iters = params.get('iterations', 20)

    # Warm up
    with torch.no_grad():
        for _ in range(20):
            features = model(*input_tensors)

    torch.cuda.synchronize()

    timings = []
    with torch.no_grad():
        for i in range(iters):
            start_time = timeit.default_timer()
            features = model(*input_tensors)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            meas_time = end_time - start_time
            timings.append(meas_time)
            print("Iteration {}: {:.6f} s".format(i, end_time - start_time))

    printStats("Torch-TensorRT", timings, precision)

def run_tensorrt(model, input_tensors, params, precision):
    print("Running TensorRT")
    bindings = []
    stream = cuda.Stream()
    iters = params.get('iterations', 20)
    batch_size = params.get('batch_size', 1)

    with onnx_to_trt_engine(model, precision) as engine, engine.create_execution_context() as context:

        # Allocate a page-locked host buffer and a device buffer for every
        # binding, and hand the device pointers to the execution context.
        # Note: input_tensors are not copied into the device buffers, so
        # timing runs on freshly allocated device memory.
        for binding in engine:
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))

        # Warm up
        for _ in range(20):
            context.execute_async(batch_size, bindings, stream.handle)

        stream.synchronize()

        timings = []
        for i in range(iters):
            start_time = timeit.default_timer()
            context.execute_async(batch_size, bindings, stream.handle)
            stream.synchronize()
            end_time = timeit.default_timer()
            meas_time = end_time - start_time
            timings.append(meas_time)
            print("Iteration {}: {:.6f} s".format(i, end_time - start_time))

    printStats("TensorRT", timings, precision)

def run(model, input_tensors, params, precision):
    # Dispatch to every backend requested in the config
    for backend in params['backend']:
        if backend == 'all':
            run_torch(model, input_tensors, params, precision)
            run_torch_tensorrt(model, input_tensors, params, precision)
            run_tensorrt(model, input_tensors, params, precision)

        elif backend == "torch":
            run_torch(model, input_tensors, params, precision)

        elif backend == "torch_tensorrt":
            run_torch_tensorrt(model, input_tensors, params, precision)

        elif backend == "tensorrt":
            run_tensorrt(model, input_tensors, params, precision)

def printStats(backend, timings, precision, batch_size=1):
    times = np.array(timings)
    steps = len(times)
    speeds = batch_size / times
    time_mean = np.mean(times)
    time_med = np.median(times)
    time_99th = np.percentile(times, 99)
    time_std = np.std(times, ddof=0)
    speed_mean = np.mean(speeds)
    speed_med = np.median(speeds)

    msg = ("\n%s =================================\n"
           "batch size=%d, num iterations=%d\n"
           "  Median FPS: %.1f, mean: %.1f\n"
           "  Median latency: %.6f s, mean: %.6f s, 99th_p: %.6f s, std_dev: %.6f s\n"
           ) % (backend,
                batch_size, steps,
                speed_med, speed_mean,
                time_med, time_mean, time_99th, time_std)
    print(msg)
    meas = {
        'Backend': backend,
        'Precision': precision,
        'Median(FPS)': speed_med,
        'Mean(FPS)': speed_mean,
        'Median-Latency(s)': time_med,
        'Mean-Latency(s)': time_mean,
        '99th_p(s)': time_99th,
        'std_dev(s)': time_std
    }
    results.append(meas)

def read_config(config_file):
    params = None
    with open(config_file, "r") as stream:
        try:
            params = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
    return params

def precision_to_dtype(pr):
    if pr == 'fp32':
        return torch.float
    elif pr == 'fp16' or pr == 'half':
        return torch.half
    else:
        return torch.int8

def load_model(params):
    model = None
    # Load the traced TorchScript module for the torch / torch_tensorrt
    # backends, or raw ONNX bytes for the standalone tensorrt backend
    if "torch" in params['backend'] or "torch_tensorrt" in params['backend']:
        model_path = os.path.join("models", params['model']['filename'])
        model = torch.jit.load(model_path).cuda()

    elif "tensorrt" in params['backend']:
        onnx_model_file = os.path.join("models", params['model']['onnx_file'])
        with open(onnx_model_file, 'rb') as onnx_model:
            print('Beginning ONNX file parsing')
            model = onnx_model.read()

    return model

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Run inference on a model with random input values")
    parser.add_argument("--config", help="Load YAML based configuration file to run the inference. If this is used other params will be ignored")
    args = parser.parse_args()

    # Load YAML params
    params = read_config(args.config)

    print("Loading model: ", params['model']['filename'])

    if "device" in params['runtime']:
        torch.cuda.set_device(params['runtime']['device'])

    model = load_model(params)

    cudnn.benchmark = True

    # Create random input tensors of the configured shapes
    torch.manual_seed(12345)

    num_input = params['input']['num_of_input']
    for precision in params['runtime']['precision']:
        input_tensors = []
        for i in range(num_input):
            inp_shape = params['input']['input' + str(i)]
            input_tensors.append(torch.randint(0, 2, tuple(inp_shape), dtype=precision_to_dtype(precision)).cuda())

        if precision == "fp16" or precision == "half":
            # Inputs are already created in half precision; only the model
            # still needs casting
            model = model.half()

        run(model, input_tensors, params, precision)

    print('Model Summary:')
    summary = pd.DataFrame(results)
    print(summary)
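
Note (a sketch, not part of the commit): vgg16.yml references models/vgg16_traced.jit.pt, but no script to produce the model files is included here. Assuming torchvision is available and torchvision's VGG16 is the intended source model, something like the following would generate both the TorchScript module and an ONNX file for the standalone tensorrt backend:

    import os
    import torch
    import torchvision.models as models

    os.makedirs("models", exist_ok=True)
    model = models.vgg16(pretrained=True).eval().cuda()
    # Example input matching the 1x3x224x224 shape in vgg16.yml
    example_input = torch.randn(1, 3, 224, 224).cuda()

    # TorchScript module for the torch / torch_tensorrt backends
    traced_model = torch.jit.trace(model, example_input)
    torch.jit.save(traced_model, os.path.join("models", "vgg16_traced.jit.pt"))

    # ONNX graph for the tensorrt backend; the config key model.onnx_file
    # (e.g. vgg16.onnx, a hypothetical name) would point at this file
    torch.onnx.export(model, example_input, os.path.join("models", "vgg16.onnx"))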
