
Commit 59113cf

Committed Oct 21, 2020
feat(//py): Initial compliant implementation of the to_backend API for PyTorch

Users can now use a direct PyTorch integration simply by importing the trtorch package. The only difference between torch._C._jit_to_tensorrt and trtorch.compile is that you need to use the trtorch.TensorRTCompileSpec constructor to build a wrapper around your spec dictionary; a usage sketch follows the change summary below.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent b24c0d8 commit 59113cf

15 files changed (+573, -133 lines)
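
For orientation, here is a minimal sketch of the workflow this commit enables, adapted from the test_to_backend_api.py test added below. The mobilenet_v2 model, input size, and spec values are illustrative; per the TensorRTCompileSpec docstring, only the "input_shapes" key is required and the remaining settings are optional.

import torch
import torchvision.models as models
import trtorch

# Illustrative model and input size, mirroring the new test below.
model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
scripted_model = torch.jit.script(model)

# Wrap the per-method settings dictionary with trtorch.TensorRTCompileSpec;
# this wrapper is the only extra step compared to trtorch.compile.
spec = {
    "forward": trtorch.TensorRTCompileSpec({
        "input_shapes": [[1, 3, 300, 300]],
        "op_precision": torch.float,
        "device_type": "gpu",
    })
}

# Hand the wrapped spec to the to_backend entry point, keyed by method name.
trt_mod = torch._C._jit_to_tensorrt(scripted_model._c, {"forward": spec})

x = torch.randn((1, 3, 300, 300)).to("cuda")
print((trt_mod.forward(x) - scripted_model(x)).abs().max())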
 

‎py/setup.py

+7 -1

@@ -156,7 +156,12 @@ def run(self):
 
 ext_modules = [
     cpp_extension.CUDAExtension('trtorch._C',
-        ['trtorch/csrc/trtorch_py.cpp'],
+        [
+            'trtorch/csrc/trtorch_py.cpp',
+            'trtorch/csrc/tensorrt_backend.cpp',
+            'trtorch/csrc/tensorrt_classes.cpp',
+            'trtorch/csrc/register_tensorrt_classes.cpp',
+        ],
         library_dirs=[
             (dir_path + '/trtorch/lib/'),
             "/opt/conda/lib/python3.6/config-3.6m-x86_64-linux-gnu"
@@ -165,6 +170,7 @@ def run(self):
             "trtorch"
         ],
         include_dirs=[
+            dir_path + "trtorch/csrc",
            dir_path + "/../",
            dir_path + "/../bazel-TRTorch/external/tensorrt/include",
        ],

‎py/trtorch/__init__.py

+1

@@ -9,6 +9,7 @@
 
 from trtorch._version import __version__
 from trtorch._compiler import *
+from trtorch._compile_spec import TensorRTCompileSpec
 from trtorch._types import *
 from trtorch import logging
 

‎py/trtorch/_compile_spec.py

+84 -9

@@ -73,16 +73,21 @@ def _parse_op_precision(precision: Any) -> _types.dtype:
 
 def _parse_device_type(device: Any) -> _types.DeviceType:
     if isinstance(device, torch.device):
-        if torch.device.type == 'cuda':
+        if device.type == 'cuda':
             return _types.DeviceType.gpu
         else:
-            raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type" + str(device.type))
-
+            ValueError("Got a device type other than GPU or DLA (type: " + str(device.type) + ")")
     elif isinstance(device, _types.DeviceType):
         return device
-
+    elif isinstance(device, str):
+        if device == "gpu" or device == "GPU":
+            return _types.DeviceType.gpu
+        elif device == "dla" or device == "DLA":
+            return _types.DeviceType.dla
+        else:
+            ValueError("Got a device type other than GPU or DLA (type: " + str(device) + ")")
     else:
-        raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))
+        raise TypeError("Device specification must be of type torch.device, string or trtorch.DeviceType, but got: " + str(type(device)))
 
 def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     info = trtorch._C.CompileSpec()
@@ -110,11 +115,11 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
         assert isinstance(compile_spec["allow_gpu_fallback"], bool)
         info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]
 
-    if "device" in compile_spec:
-        info.device = _parse_device_type(compile_spec["device"])
+    if "device_type" in compile_spec:
+        info.device = _parse_device_type(compile_spec["device_type"])
 
     if "capability" in compile_spec:
-        assert isinstance(compile_spec["capability"], type.EngineCapability)
+        assert isinstance(compile_spec["capability"], _types.EngineCapability)
         info.capability = compile_spec["capability"]
 
     if "num_min_timing_iters" in compile_spec:
@@ -133,4 +138,74 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
         assert type(compile_spec["max_batch_size"]) is int
         info.max_batch_size = compile_spec["max_batch_size"]
 
-    return info
+    return info
+
+def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
+    """
+    Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
+
+    Args:
+        compile_spec (dict): Compilation settings including operating precision, target device, etc.
+            One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
+            to the graph. All other keys are optional. Entries for each method to be compiled.
+
+            .. code-block:: py
+
+                CompileSpec = {
+                    "forward" : trtorch.TensorRTCompileSpec({
+                        "input_shapes": [
+                            (1, 3, 224, 224), # Static input shape for input #1
+                            {
+                                "min": (1, 3, 224, 224),
+                                "opt": (1, 3, 512, 512),
+                                "max": (1, 3, 1024, 1024)
+                            } # Dynamic input shape for input #2
+                        ],
+                        "op_precision": torch.half, # Operating precision set to FP16
+                        "refit": false, # enable refit
+                        "debug": false, # enable debuggable engine
+                        "strict_types": false, # kernels should strictly run in operating precision
+                        "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
+                        "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+                        "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
+                        "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
+                        "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
+                        "workspace_size": 0, # Maximum size of workspace given to TensorRT
+                        "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+                    })
+                }
+
+        Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+        torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
+        to select device type.
+
+    Returns:
+        torch.classes.tensorrt.CompileSpec: List of methods and formated spec objects to be provided to ``torch._C._jit_to_tensorrt``
+    """
+
+    parsed_spec = _parse_compile_spec(compile_spec)
+
+    backend_spec = torch.classes.tensorrt.CompileSpec()
+
+    for i in parsed_spec.input_ranges:
+        ir = torch.classes.tensorrt.InputRange()
+        ir.set_min(i.min)
+        ir.set_opt(i.opt)
+        ir.set_max(i.max)
+        backend_spec.append_input_range(ir)
+
+    backend_spec.set_op_precision(int(parsed_spec.op_precision))
+    backend_spec.set_refit(parsed_spec.refit)
+    backend_spec.set_debug(parsed_spec.debug)
+    backend_spec.set_refit(parsed_spec.refit)
+    backend_spec.set_strict_types(parsed_spec.strict_types)
+    backend_spec.set_allow_gpu_fallback(parsed_spec.allow_gpu_fallback)
+    backend_spec.set_device(int(parsed_spec.device))
+    backend_spec.set_capability(int(parsed_spec.capability))
+    backend_spec.set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
+    backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
+    backend_spec.set_workspace_size(parsed_spec.workspace_size)
+    backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
+
+    return backend_spec
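
To make the _parse_device_type change above concrete, the "device_type" entry may now be given as a torch.device, a string, or a trtorch.DeviceType value. A short sketch of the three equivalent forms follows; the lowercase enum attribute name mirrors the _types usage inside the parser and is an assumption here.

import torch
import trtorch

# These three spec fragments all resolve to the GPU device type in _parse_device_type:
spec_from_device = {"input_shapes": [[1, 3, 224, 224]], "device_type": torch.device("cuda")}
spec_from_string = {"input_shapes": [[1, 3, 224, 224]], "device_type": "gpu"}  # "GPU", "dla", "DLA" also accepted
spec_from_enum = {"input_shapes": [[1, 3, 224, 224]], "device_type": trtorch.DeviceType.gpu}

forward_spec = trtorch.TensorRTCompileSpec(spec_from_string)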

‎py/trtorch/_compiler.py

+2 -2

@@ -39,7 +39,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
             "debug": false, # enable debuggable engine
             "strict_types": false, # kernels should strictly run in operating precision
             "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
-            "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+            "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
             "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
             "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
             "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
@@ -91,7 +91,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
             "debug": false, # enable debuggable engine
             "strict_types": false, # kernels should strictly run in operating precision
             "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
-            "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+            "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
             "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
             "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
             "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
py/trtorch/csrc/register_tensorrt_classes.cpp

+47

@@ -0,0 +1,47 @@
+#include "tensorrt_classes.h"
+
+namespace trtorch {
+namespace backend {
+namespace {
+void RegisterTRTCompileSpec() {
+#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
+  (registry).def("set_"#field_name, &class_name::set_##field_name); \
+  (registry).def("get_"#field_name, &class_name::get_##field_name);
+
+  static auto TRTORCH_UNUSED TRTInputRangeTSRegistrtion = torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange")
+    .def(torch::init<>());
+
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, min);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, opt);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, max);
+
+  static auto TRTORCH_UNUSED TRTCompileSpecTSRegistrtion = torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
+    .def(torch::init<>())
+    .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
+    .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
+
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, op_precision);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, refit);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, debug);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, strict_types);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, allow_gpu_fallback);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, device);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, capability);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, num_min_timing_iters);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, workspace_size);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, max_batch_size);
+}
+
+struct TRTTSRegistrations {
+  TRTTSRegistrations() {
+    RegisterTRTCompileSpec();
+  }
+};
+
+static TRTTSRegistrations register_trt_classes = TRTTSRegistrations();
+}
+} // namespace backend
+} // namespace trtorch
+
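
The registration above is what exposes torch.classes.tensorrt.InputRange and torch.classes.tensorrt.CompileSpec to Python and TorchScript. Below is a small sketch of the setter/getter round trip provided by ADD_FIELD_GET_SET_REGISTRATION; the set_* calls mirror what _compile_spec.py does, while calling the matching get_* methods from Python eager mode is an assumption based on the symmetric registration.

import torch
import trtorch  # importing trtorch loads trtorch._C, whose static initializer registers the classes

ir = torch.classes.tensorrt.InputRange()
ir.set_min([1, 3, 224, 224])
ir.set_opt([1, 3, 512, 512])
ir.set_max([1, 3, 1024, 1024])
print(ir.get_opt())  # registered getter returning the stored shape (assumed callable from eager Python)

spec = torch.classes.tensorrt.CompileSpec()
spec.append_input_range(ir)
spec.set_op_precision(0)  # 0 maps to DataType::kFloat in tensorrt_classes.h below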

‎py/trtorch/csrc/tensorrt_backend.cpp

+86

@@ -0,0 +1,86 @@
+#include "torch/csrc/jit/passes/lower_graph.h"
+
+#include "tensorrt_backend.h"
+#include "tensorrt_classes.h"
+
+#include "core/compiler.h"
+#include "core/lowering/lowering.h"
+#include "core/runtime/runtime.h"
+
+namespace trtorch {
+namespace backend {
+
+c10::IValue TensorRTBackend::preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) {
+  auto mod_ = mod.toModule();
+  LOG_DEBUG("Placing module in eval mode if not already");
+  mod_.eval();
+  mod_ = core::lowering::LowerModule(mod_);
+
+  auto spec =
+      c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
+
+  for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+    TRTORCH_CHECK(core::CheckMethodOperatorSupport(mod.toModule(), it->key()),
+                  "Method " << it->key() << "cannot be compiled by TRTorch");
+  }
+
+  for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+    const auto& method_name = it->key();
+    auto method = mod_.get_method(method_name);
+    auto graph = method.graph();
+    core::lowering::LowerGraph(graph);
+  }
+
+  return mod_._ivalue();
+}
+
+c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::impl::GenericDict method_compile_spec) {
+  auto mod = processed_mod.toModule();
+  auto spec =
+      c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
+
+  auto handles = c10::impl::GenericDict(c10::StringType::get(), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());
+
+  for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+    const auto& method_name = it->key();
+    auto method = mod.get_method(method_name);
+    auto g = method.graph();
+
+    auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
+    LOG_DEBUG(raw_spec->stringify());
+    auto cfg = raw_spec->toInternalCompileSpec();
+    auto convert_cfg = std::move(cfg.convert_info);
+    auto graph_and_ivalues = torch::jit::LowerGraph(*g, mod._ivalue());
+
+    g = graph_and_ivalues.first;
+    auto params = graph_and_ivalues.second;
+    auto named_params = core::conversion::get_named_params(g->inputs(), params);
+
+    auto serialized_engine = core::conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
+    auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key(), serialized_engine);
+    handles.insert(method.name(), at::IValue(engine_handle));
+  }
+
+  return c10::impl::toGenericDict(handles);
+}
+
+c10::impl::GenericList TensorRTBackend::execute(c10::IValue handle, c10::impl::GenericList inputs) {
+  TRTORCH_ASSERT(inputs.size() > 0, "Trying to execute on empty list of arguments");
+  auto engine = handle.toCustomClass<core::runtime::TRTEngine>();
+  std::vector<at::Tensor> in_vec;
+  for (size_t i = 0, e = inputs.size(); i < e; ++i) {
+    c10::IValue val = inputs[i];
+    TRTORCH_CHECK(val.isTensor(), "TensorRT currently only accepts Tensors as inputs");
+    in_vec.push_back(val.toTensor());
+  }
+  auto outputs = core::runtime::execute_engine(in_vec, engine);
+  return c10::impl::toList(c10::List<at::Tensor>(outputs));
+}
+
+namespace {
+static auto reg = torch::jit::backend<TensorRTBackend>("tensorrt");
+}
+
+} // namespace backend
+} // namespace trtorch

‎py/trtorch/csrc/tensorrt_backend.h

+19

@@ -0,0 +1,19 @@
+#pragma once
+#include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/backends/backend.h"
+
+namespace trtorch {
+namespace backend {
+
+class TensorRTBackend: public torch::jit::PyTorchBackendInterface {
+ public:
+  explicit TensorRTBackend() {}
+  virtual ~TensorRTBackend() = default;
+
+  c10::IValue preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) override;
+  c10::impl::GenericDict compile(c10::IValue processed_mod, c10::impl::GenericDict method_compile_spec) override;
+  c10::impl::GenericList execute(c10::IValue handle, c10::impl::GenericList inputs) override;
+};
+
+} // namespace backend
+} // namespace trtorch

‎py/trtorch/csrc/tensorrt_classes.cpp

+143

@@ -0,0 +1,143 @@
+
+#include "tensorrt_classes.h"
+
+namespace trtorch {
+namespace pyapi {
+
+std::string to_str(InputRange& value) {
+  auto vec_to_str = [](std::vector<int64_t> shape) -> std::string {
+    std::stringstream ss;
+    ss << '[';
+    for(auto i : shape) {
+      ss << i << ',';
+    }
+    ss << ']';
+    return ss.str();
+  };
+
+  std::stringstream ss;
+  ss << " {" << std::endl;
+  ss << " min: " << vec_to_str(value.min) << ',' << std::endl;
+  ss << " opt: " << vec_to_str(value.opt) << ',' << std::endl;
+  ss << " max: " << vec_to_str(value.max) << ',' << std::endl;
+  ss << " }" << std::endl;
+  return ss.str();
+}
+
+std::string to_str(DataType value) {
+  switch (value) {
+  case DataType::kHalf:
+    return "Half";
+  case DataType::kChar:
+    return "Int8";
+  case DataType::kFloat:
+  default:
+    return "Float";
+  }
+}
+
+nvinfer1::DataType toTRTDataType(DataType value) {
+  switch (value) {
+  case DataType::kChar:
+    return nvinfer1::DataType::kINT8;
+  case DataType::kHalf:
+    return nvinfer1::DataType::kHALF;
+  case DataType::kFloat:
+  default:
+    return nvinfer1::DataType::kFLOAT;
+  }
+}
+
+std::string to_str(DeviceType value) {
+  switch (value) {
+  case DeviceType::kDLA:
+    return "DLA";
+  case DeviceType::kGPU:
+  default:
+    return "GPU";
+  }
+}
+
+nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
+  switch (value) {
+  case DeviceType::kDLA:
+    return nvinfer1::DeviceType::kDLA;
+  case DeviceType::kGPU:
+  default:
+    return nvinfer1::DeviceType::kGPU;
+  }
+}
+
+std::string to_str(EngineCapability value) {
+  switch (value) {
+  case EngineCapability::kSAFE_GPU:
+    return "Safe GPU";
+  case EngineCapability::kSAFE_DLA:
+    return "Safe DLA";
+  case EngineCapability::kDEFAULT:
+  default:
+    return "Default";
+  }
+}
+
+nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
+  switch (value) {
+  case EngineCapability::kSAFE_DLA:
+    return nvinfer1::EngineCapability::kSAFE_DLA;
+  case EngineCapability::kSAFE_GPU:
+    return nvinfer1::EngineCapability::kSAFE_GPU;
+  case EngineCapability::kDEFAULT:
+  default:
+    return nvinfer1::EngineCapability::kDEFAULT;
+  }
+}
+
+core::CompileSpec CompileSpec::toInternalCompileSpec() {
+  std::vector<core::conversion::InputRange> internal_input_ranges;
+  for (auto i : input_ranges) {
+    internal_input_ranges.push_back(i.toInternalInputRange());
+  }
+  auto info = core::CompileSpec(internal_input_ranges);
+  info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
+  info.convert_info.engine_settings.refit = refit;
+  info.convert_info.engine_settings.debug = debug;
+  info.convert_info.engine_settings.strict_types = strict_types;
+  info.convert_info.engine_settings.allow_gpu_fallback = allow_gpu_fallback;
+  info.convert_info.engine_settings.device = toTRTDeviceType(device);
+  info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
+  TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
+  info.convert_info.engine_settings.num_min_timing_iters = num_min_timing_iters;
+  TRTORCH_CHECK(num_avg_timing_iters >= 0, "num_avg_timing_iters must be 0 or greater");
+  info.convert_info.engine_settings.num_avg_timing_iters = num_avg_timing_iters;
+  TRTORCH_CHECK(workspace_size >= 0, "workspace_size must be 0 or greater");
+  info.convert_info.engine_settings.workspace_size = workspace_size;
+  TRTORCH_CHECK(max_batch_size >= 0, "max_batch_size must be 0 or greater");
+  info.convert_info.engine_settings.max_batch_size = max_batch_size;
+  return info;
+}
+
+std::string CompileSpec::stringify() {
+  std::stringstream ss;
+  ss << "TensorRT Compile Spec: {" << std::endl;
+  ss << " \"Input Shapes\": [" << std::endl;
+  for (auto i : input_ranges) {
+    ss << to_str(i);
+  }
+  ss << " ]" << std::endl;
+  ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
+  ss << " \"Refit\": " << refit << std::endl;
+  ss << " \"Debug\": " << debug << std::endl;
+  ss << " \"Strict Types\": " << strict_types << std::endl;
+  ss << " \"Allow GPU Fallback\": " << allow_gpu_fallback << std::endl;
+  ss << " \"Device\": " << to_str(capability) << std::endl;
+  ss << " \"Engine Capability\": " << to_str(capability) << std::endl;
+  ss << " \"Num Min Timing Iters\": " << num_min_timing_iters << std::endl;
+  ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
+  ss << " \"Workspace Size\": " << workspace_size << std::endl;
+  ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
+  ss << "}";
+  return ss.str();
+}
+
+} // namespace pyapi
+} // namespace trtorch

‎py/trtorch/csrc/tensorrt_classes.h

+101

@@ -0,0 +1,101 @@
+#pragma once
+
+#include "core/compiler.h"
+#include "core/conversion/conversion.h"
+#include "torch/torch.h"
+#include "torch/script.h"
+#include "torch/custom_class.h"
+
+namespace trtorch {
+namespace pyapi {
+
+#define ADD_FIELD_GET_SET(field_name, type) \
+  void set_##field_name(type val) {field_name = val;} \
+  type get_##field_name() {return field_name;}
+
+struct InputRange : torch::CustomClassHolder {
+  std::vector<int64_t> min;
+  std::vector<int64_t> opt;
+  std::vector<int64_t> max;
+
+  core::conversion::InputRange toInternalInputRange() {
+    return core::conversion::InputRange(min, opt, max);
+  }
+
+  ADD_FIELD_GET_SET(min, std::vector<int64_t>);
+  ADD_FIELD_GET_SET(opt, std::vector<int64_t>);
+  ADD_FIELD_GET_SET(max, std::vector<int64_t>);
+};
+
+std::string to_str(InputRange& value);
+
+enum class DataType : int8_t {
+  kFloat,
+  kHalf,
+  kChar,
+};
+
+std::string to_str(DataType value);
+nvinfer1::DataType toTRTDataType(DataType value);
+
+enum DeviceType : int8_t {
+  kGPU,
+  kDLA,
+};
+
+std::string to_str(DeviceType value);
+nvinfer1::DeviceType toTRTDeviceType(DeviceType value);
+
+enum class EngineCapability : int8_t {
+  kDEFAULT,
+  kSAFE_GPU,
+  kSAFE_DLA,
+};
+
+std::string to_str(EngineCapability value);
+nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);
+
+// TODO: Make this error message more informative
+#define ADD_ENUM_GET_SET(field_name, type, max_val) \
+  void set_##field_name(int64_t val) { \
+    TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
+    field_name = static_cast<type>(val); \
+  } \
+  int64_t get_##field_name() {return static_cast<int64_t>(field_name);}
+
+struct CompileSpec : torch::CustomClassHolder {
+  core::CompileSpec toInternalCompileSpec();
+  std::string stringify();
+  void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
+    input_ranges.push_back(*ir);
+  }
+
+  ADD_ENUM_GET_SET(op_precision, DataType, 3);
+  ADD_FIELD_GET_SET(refit, bool);
+  ADD_FIELD_GET_SET(debug, bool);
+  ADD_FIELD_GET_SET(strict_types, bool);
+  ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
+  ADD_ENUM_GET_SET(device, DeviceType, 2);
+  ADD_ENUM_GET_SET(capability, EngineCapability, 3);
+  ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
+  ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
+  ADD_FIELD_GET_SET(workspace_size, int64_t);
+  ADD_FIELD_GET_SET(max_batch_size, int64_t);
+
+  std::vector<InputRange> input_ranges;
+  DataType op_precision = DataType::kFloat;
+  bool refit = false;
+  bool debug = false;
+  bool strict_types = false;
+  bool allow_gpu_fallback = true;
+  DeviceType device = DeviceType::kGPU;
+  EngineCapability capability = EngineCapability::kDEFAULT;
+  int64_t num_min_timing_iters = 2;
+  int64_t num_avg_timing_iters = 1;
+  int64_t workspace_size = 0;
+  int64_t max_batch_size = 0;
+};
+
+} // namespace pyapi
+} // namespace trtorch

‎py/trtorch/csrc/trtorch_py.cpp

+3 -103

@@ -1,11 +1,12 @@
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
-//TODO: Remove when we have access to PyTorch to_backend autoregistration
-#include "core/backend.h"
+
+#include "tensorrt_classes.h"
 #include "core/compiler.h"
 #include "core/conversion/conversion.h"
 #include "torch/torch.h"
 #include "torch/script.h"
+#include "torch/custom_class.h"
 #include "torch/csrc/jit/python/pybind_utils.h"
 #include "Python.h"
 
@@ -14,103 +15,6 @@ namespace py = pybind11;
 namespace trtorch {
 namespace pyapi {
 
-struct InputRange {
-  std::vector<int64_t> min;
-  std::vector<int64_t> opt;
-  std::vector<int64_t> max;
-
-  core::conversion::InputRange toInternalInputRange() {
-    return core::conversion::InputRange(min, opt, max);
-  }
-};
-
-enum class DataType : int8_t {
-  kFloat,
-  kHalf,
-  kChar,
-};
-
-nvinfer1::DataType toTRTDataType(DataType value) {
-  switch (value) {
-  case DataType::kChar:
-    return nvinfer1::DataType::kINT8;
-  case DataType::kHalf:
-    return nvinfer1::DataType::kHALF;
-  case DataType::kFloat:
-  default:
-    return nvinfer1::DataType::kFLOAT;
-  }
-}
-
-enum DeviceType : int8_t {
-  kGPU,
-  kDLA,
-};
-
-nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
-  switch (value) {
-  case DeviceType::kDLA:
-    return nvinfer1::DeviceType::kDLA;
-  case DeviceType::kGPU:
-  default:
-    return nvinfer1::DeviceType::kGPU;
-  }
-}
-
-enum class EngineCapability : int8_t {
-  kDEFAULT,
-  kSAFE_GPU,
-  kSAFE_DLA,
-};
-
-nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
-  switch (value) {
-  case EngineCapability::kSAFE_DLA:
-    return nvinfer1::EngineCapability::kSAFE_DLA;
-  case EngineCapability::kSAFE_GPU:
-    return nvinfer1::EngineCapability::kSAFE_GPU;
-  case EngineCapability::kDEFAULT:
-  default:
-    return nvinfer1::EngineCapability::kDEFAULT;
-  }
-}
-
-struct CompileSpec {
-
-  core::CompileSpec toInternalCompileSpec() {
-    for (auto i : input_ranges) {
-      internal_input_ranges.push_back(i.toInternalInputRange());
-    }
-    auto info = core::CompileSpec(internal_input_ranges);
-    info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
-    info.convert_info.engine_settings.refit = refit;
-    info.convert_info.engine_settings.debug = debug;
-    info.convert_info.engine_settings.strict_types = strict_types;
-    info.convert_info.engine_settings.allow_gpu_fallback = allow_gpu_fallback;
-    info.convert_info.engine_settings.device = toTRTDeviceType(device);
-    info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
-    info.convert_info.engine_settings.num_min_timing_iters = num_min_timing_iters;
-    info.convert_info.engine_settings.num_avg_timing_iters = num_avg_timing_iters;
-    info.convert_info.engine_settings.workspace_size = workspace_size;
-    info.convert_info.engine_settings.max_batch_size = max_batch_size;
-    return info;
-  }
-
-  std::vector<InputRange> input_ranges;
-  std::vector<core::conversion::InputRange> internal_input_ranges;
-  DataType op_precision = DataType::kFloat;
-  bool refit = false;
-  bool debug = false;
-  bool strict_types = false;
-  bool allow_gpu_fallback = true;
-  DeviceType device = DeviceType::kGPU;
-  EngineCapability capability = EngineCapability::kDEFAULT;
-  uint64_t num_min_timing_iters = 2;
-  uint64_t num_avg_timing_iters = 1;
-  uint64_t workspace_size = 0;
-  uint64_t max_batch_size = 0;
-};
-
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info) {
   py::gil_scoped_acquire gil;
   auto trt_mod = core::CompileGraph(mod, info.toInternalCompileSpec());
@@ -227,11 +131,7 @@
     .value("INFO", core::util::logging::LogLevel::kINFO)
     .value("DEBUG", core::util::logging::LogLevel::kDEBUG)
     .export_values();
-
-  //TODO: Remove when we have access to PyTorch autoregistration
-  //m.def("to_tensorrt", backend::GetTensorRTBackend().generateToBackendFn());
 }
 
-
 } // namespace pyapi
 } // namespace trtorch

‎tests/BUILD

+2 -1

@@ -17,6 +17,7 @@ test_suite(
 test_suite(
     name = "python_api_tests",
     tests = [
-        "//tests/py:test_api"
+        "//tests/py:test_api",
+        "//tests/py:test_to_backend_api"
     ]
 )

‎tests/py/BUILD

+14 -2

@@ -5,9 +5,21 @@ load("@py_test_deps//:requirements.bzl", "requirement")
 py_test(
     name = "test_api",
     srcs = [
-        "test_api.py"
+        "test_api.py",
+        "model_test_case.py"
     ],
     deps = [
         requirement("torchvision")
     ]
-)
+)
+
+py_test(
+    name = "test_to_backend_api",
+    srcs = [
+        "test_to_backend_api.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)

‎tests/py/model_test_case.py

+19

@@ -0,0 +1,19 @@
+import unittest
+import trtorch
+import torch
+import torchvision.models as models
+
+class ModelTestCase(unittest.TestCase):
+    def __init__(self, methodName='runTest', model=None):
+        super(ModelTestCase, self).__init__(methodName)
+        self.model = model
+        self.model.eval().to("cuda")
+
+    @staticmethod
+    def parametrize(testcase_class, model=None):
+        testloader = unittest.TestLoader()
+        testnames = testloader.getTestCaseNames(testcase_class)
+        suite = unittest.TestSuite()
+        for name in testnames:
+            suite.addTest(testcase_class(name, model=model))
+        return suite

‎tests/py/test_api.py

+1 -15

@@ -3,21 +3,7 @@
 import torch
 import torchvision.models as models
 
-
-class ModelTestCase(unittest.TestCase):
-    def __init__(self, methodName='runTest', model=None):
-        super(ModelTestCase, self).__init__(methodName)
-        self.model = model
-        self.model.eval().to("cuda")
-
-    @staticmethod
-    def parametrize(testcase_class, model=None):
-        testloader = unittest.TestLoader()
-        testnames = testloader.getTestCaseNames(testcase_class)
-        suite = unittest.TestSuite()
-        for name in testnames:
-            suite.addTest(testcase_class(name, model=model))
-        return suite
+from model_test_case import ModelTestCase
 
 class TestCompile(ModelTestCase):
     def setUp(self):

‎tests/py/test_to_backend_api.py

+44

@@ -0,0 +1,44 @@
+import unittest
+import trtorch
+import torch
+import torchvision.models as models
+
+from model_test_case import ModelTestCase
+
+class TestToBackendLowering(ModelTestCase):
+    def setUp(self):
+        self.input = torch.randn((1, 3, 300, 300)).to("cuda")
+        self.scripted_model = torch.jit.script(self.model)
+        self.spec = {
+            "forward": trtorch.TensorRTCompileSpec({
+                "input_shapes": [[1, 3, 300, 300]],
+                "op_precision": torch.float,
+                "refit": False,
+                "debug": False,
+                "strict_types": False,
+                "allow_gpu_fallback": True,
+                "device_type": "gpu",
+                "capability": trtorch.EngineCapability.default,
+                "num_min_timing_iters": 2,
+                "num_avg_timing_iters": 1,
+                "max_batch_size": 0,
+            })
+        }
+
+    def test_to_backend_lowering(self):
+        trt_mod = torch._C._jit_to_tensorrt(self.scripted_model._c, {"forward": self.spec})
+        same = (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)
+
+def test_suite():
+    suite = unittest.TestSuite()
+    suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.mobilenet_v2(pretrained=True)))
+
+    return suite
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
