
Commit 59113cf

feat(//py): Initial compliant implementation of the to_backend API for PyTorch

Users can now use a direct PyTorch integration by just importing the trtorch package. The only difference between torch._C._jit_to_tensorrt and trtorch.compile is that you need to use the trtorch.TensorRTCompileSpec constructor to build a wrapper around your spec dictionary.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent b24c0d8 commit 59113cf

15 files changed: +573 -133 lines
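
For orientation (not part of the commit itself), a minimal sketch of the workflow the commit message describes, assuming an existing torch.jit.ScriptModule named `scripted_model` on CUDA; the extra level of nesting in the spec dict is inferred from how `TensorRTBackend::compile` indexes it in the diff below and may differ from the final API:

    import torch
    import trtorch

    # Wrap a plain spec dictionary so it can cross into the TorchScript backend API.
    spec = {
        "forward": trtorch.TensorRTCompileSpec({
            "input_shapes": [(1, 3, 300, 300)],
            "op_precision": torch.float,
            "device_type": "gpu",
        })
    }

    # Direct PyTorch integration: torch._C._jit_to_tensorrt instead of trtorch.compile.
    # The backend looks up method_compile_spec["forward"]["forward"], hence the outer dict.
    trt_module = torch._C._jit_to_tensorrt(scripted_model._c, {"forward": spec})
    out = trt_module.forward(torch.randn(1, 3, 300, 300).cuda())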

Diff for: py/setup.py

+7 -1

@@ -156,7 +156,12 @@ def run(self):
 
 ext_modules = [
     cpp_extension.CUDAExtension('trtorch._C',
-        ['trtorch/csrc/trtorch_py.cpp'],
+        [
+            'trtorch/csrc/trtorch_py.cpp',
+            'trtorch/csrc/tensorrt_backend.cpp',
+            'trtorch/csrc/tensorrt_classes.cpp',
+            'trtorch/csrc/register_tensorrt_classes.cpp',
+        ],
         library_dirs=[
             (dir_path + '/trtorch/lib/'),
             "/opt/conda/lib/python3.6/config-3.6m-x86_64-linux-gnu"
@@ -165,6 +170,7 @@ def run(self):
             "trtorch"
         ],
         include_dirs=[
+            dir_path + "trtorch/csrc",
            dir_path + "/../",
            dir_path + "/../bazel-TRTorch/external/tensorrt/include",
        ],

Diff for: py/trtorch/__init__.py

+1

@@ -9,6 +9,7 @@
 
 from trtorch._version import __version__
 from trtorch._compiler import *
+from trtorch._compile_spec import TensorRTCompileSpec
 from trtorch._types import *
 from trtorch import logging
 
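
Not part of the diff: with this re-export in place, the wrapper is reachable from the package top level, e.g.:

    import trtorch

    # TensorRTCompileSpec is now importable directly from the trtorch namespace.
    spec = trtorch.TensorRTCompileSpec({"input_shapes": [(1, 3, 224, 224)]})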

Diff for: py/trtorch/_compile_spec.py

+84 -9

@@ -73,16 +73,21 @@ def _parse_op_precision(precision: Any) -> _types.dtype:
 
 def _parse_device_type(device: Any) -> _types.DeviceType:
     if isinstance(device, torch.device):
-        if torch.device.type == 'cuda':
+        if device.type == 'cuda':
             return _types.DeviceType.gpu
         else:
-            raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type" + str(device.type))
-
+            ValueError("Got a device type other than GPU or DLA (type: " + str(device.type) + ")")
     elif isinstance(device, _types.DeviceType):
         return device
-
+    elif isinstance(device, str):
+        if device == "gpu" or device == "GPU":
+            return _types.DeviceType.gpu
+        elif device == "dla" or device == "DLA":
+            return _types.DeviceType.dla
+        else:
+            ValueError("Got a device type other than GPU or DLA (type: " + str(device) + ")")
     else:
-        raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))
+        raise TypeError("Device specification must be of type torch.device, string or trtorch.DeviceType, but got: " + str(type(device)))
 
 def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     info = trtorch._C.CompileSpec()
@@ -110,11 +115,11 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     assert isinstance(compile_spec["allow_gpu_fallback"], bool)
     info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]
 
-    if "device" in compile_spec:
-        info.device = _parse_device_type(compile_spec["device"])
+    if "device_type" in compile_spec:
+        info.device = _parse_device_type(compile_spec["device_type"])
 
     if "capability" in compile_spec:
-        assert isinstance(compile_spec["capability"], type.EngineCapability)
+        assert isinstance(compile_spec["capability"], _types.EngineCapability)
         info.capability = compile_spec["capability"]
 
     if "num_min_timing_iters" in compile_spec:
@@ -133,4 +138,74 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     assert type(compile_spec["max_batch_size"]) is int
     info.max_batch_size = compile_spec["max_batch_size"]
 
-    return info
+    return info
+
+def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
+    """
+    Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
+
+    Args:
+        compile_spec (dict): Compilation settings including operating precision, target device, etc.
+            One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
+            to the graph. All other keys are optional. Entries for each method to be compiled.
+
+            .. code-block:: py
+
+                CompileSpec = {
+                    "forward" : trtorch.TensorRTCompileSpec({
+                        "input_shapes": [
+                            (1, 3, 224, 224), # Static input shape for input #1
+                            {
+                                "min": (1, 3, 224, 224),
+                                "opt": (1, 3, 512, 512),
+                                "max": (1, 3, 1024, 1024)
+                            } # Dynamic input shape for input #2
+                        ],
+                        "op_precision": torch.half, # Operating precision set to FP16
+                        "refit": false, # enable refit
+                        "debug": false, # enable debuggable engine
+                        "strict_types": false, # kernels should strictly run in operating precision
+                        "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
+                        "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+                        "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
+                        "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
+                        "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
+                        "workspace_size": 0, # Maximum size of workspace given to TensorRT
+                        "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+                    })
+                }
+
+        Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+        torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
+        to select device type.
+
+    Returns:
+        torch.classes.tensorrt.CompileSpec: List of methods and formated spec objects to be provided to ``torch._C._jit_to_tensorrt``
+    """
+
+    parsed_spec = _parse_compile_spec(compile_spec)
+
+    backend_spec = torch.classes.tensorrt.CompileSpec()
+
+    for i in parsed_spec.input_ranges:
+        ir = torch.classes.tensorrt.InputRange()
+        ir.set_min(i.min)
+        ir.set_opt(i.opt)
+        ir.set_max(i.max)
+        backend_spec.append_input_range(ir)
+
+    backend_spec.set_op_precision(int(parsed_spec.op_precision))
+    backend_spec.set_refit(parsed_spec.refit)
+    backend_spec.set_debug(parsed_spec.debug)
+    backend_spec.set_refit(parsed_spec.refit)
+    backend_spec.set_strict_types(parsed_spec.strict_types)
+    backend_spec.set_allow_gpu_fallback(parsed_spec.allow_gpu_fallback)
+    backend_spec.set_device(int(parsed_spec.device))
+    backend_spec.set_capability(int(parsed_spec.capability))
+    backend_spec.set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
+    backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
+    backend_spec.set_workspace_size(parsed_spec.workspace_size)
+    backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
+
+    return backend_spec
+
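
Not part of the diff: a minimal sketch of how the reworked device handling is exercised through `trtorch.compile`; the model is a placeholder, and the accepted forms listed in the comment are read directly off `_parse_device_type` above:

    import torch
    import trtorch

    # Placeholder model; any scripted module on CUDA would do.
    model = torch.jit.script(torch.nn.ReLU().eval()).cuda()

    trt_model = trtorch.compile(model, {
        "input_shapes": [(1, 3, 224, 224)],
        "op_precision": torch.half,
        # The key is now "device_type" (renamed from "device"); the parser accepts
        # torch.device("cuda"), the strings "gpu"/"GPU"/"dla"/"DLA", or a trtorch DeviceType enum value.
        "device_type": "gpu",
    })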

Diff for: py/trtorch/_compiler.py

+2 -2

@@ -39,7 +39,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
                 "debug": false, # enable debuggable engine
                 "strict_types": false, # kernels should strictly run in operating precision
                 "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
-                "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+                "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
                 "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
                 "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
                 "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
@@ -91,7 +91,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
                 "debug": false, # enable debuggable engine
                 "strict_types": false, # kernels should strictly run in operating precision
                 "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
-                "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+                "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
                 "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
                 "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
                 "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels

Diff for: py/trtorch/csrc/register_tensorrt_classes.cpp

+47

@@ -0,0 +1,47 @@
+#include "tensorrt_classes.h"
+
+namespace trtorch {
+namespace backend {
+namespace {
+void RegisterTRTCompileSpec() {
+#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
+    (registry).def("set_"#field_name, &class_name::set_##field_name);    \
+    (registry).def("get_"#field_name, &class_name::get_##field_name);
+
+    static auto TRTORCH_UNUSED TRTInputRangeTSRegistrtion = torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange")
+        .def(torch::init<>());
+
+    ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, min);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, opt);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, max);
+
+    static auto TRTORCH_UNUSED TRTCompileSpecTSRegistrtion = torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
+        .def(torch::init<>())
+        .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
+        .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
+
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, op_precision);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, refit);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, debug);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, strict_types);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, allow_gpu_fallback);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, device);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, capability);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, num_min_timing_iters);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, workspace_size);
+    ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, max_batch_size);
+}
+
+struct TRTTSRegistrations {
+    TRTTSRegistrations() {
+        RegisterTRTCompileSpec();
+    }
+};
+
+static TRTTSRegistrations register_trt_classes = TRTTSRegistrations();
+}
+} // namespace backend
+} // namespace trtorch
+
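
Not part of the commit: a sketch of how the classes registered above surface in Python once `import trtorch` has loaded the extension (which runs `RegisterTRTCompileSpec()`); this mirrors what `TensorRTCompileSpec` does internally, and the list-of-ints argument type is an assumption:

    import torch
    import trtorch  # loading the extension registers the tensorrt.* custom classes

    ir = torch.classes.tensorrt.InputRange()
    ir.set_min([1, 3, 224, 224])   # setter generated by ADD_FIELD_GET_SET_REGISTRATION
    ir.set_opt([1, 3, 512, 512])
    ir.set_max([1, 3, 1024, 1024])

    spec = torch.classes.tensorrt.CompileSpec()
    spec.append_input_range(ir)
    print(spec)  # __str__ is bound to CompileSpec::stringify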

Diff for: py/trtorch/csrc/tensorrt_backend.cpp

+86

@@ -0,0 +1,86 @@
+#include "torch/csrc/jit/passes/lower_graph.h"
+
+#include "tensorrt_backend.h"
+#include "tensorrt_classes.h"
+
+#include "core/compiler.h"
+#include "core/lowering/lowering.h"
+#include "core/runtime/runtime.h"
+
+namespace trtorch {
+namespace backend {
+
+c10::IValue TensorRTBackend::preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) {
+    auto mod_ = mod.toModule();
+    LOG_DEBUG("Placing module in eval mode if not already");
+    mod_.eval();
+    mod_ = core::lowering::LowerModule(mod_);
+
+    auto spec =
+        c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
+
+    for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+        TRTORCH_CHECK(core::CheckMethodOperatorSupport(mod.toModule(), it->key()),
+                      "Method " << it->key() << "cannot be compiled by TRTorch");
+    }
+
+    for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+        const auto& method_name = it->key();
+        auto method = mod_.get_method(method_name);
+        auto graph = method.graph();
+        core::lowering::LowerGraph(graph);
+    }
+
+    return mod_._ivalue();
+}
+
+c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::impl::GenericDict method_compile_spec) {
+    auto mod = processed_mod.toModule();
+    auto spec =
+        c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
+
+    auto handles = c10::impl::GenericDict(c10::StringType::get(), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());
+
+    for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+        const auto& method_name = it->key();
+        auto method = mod.get_method(method_name);
+        auto g = method.graph();
+
+        auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
+        LOG_DEBUG(raw_spec->stringify());
+        auto cfg = raw_spec->toInternalCompileSpec();
+        auto convert_cfg = std::move(cfg.convert_info);
+        auto graph_and_ivalues = torch::jit::LowerGraph(*g, mod._ivalue());
+
+        g = graph_and_ivalues.first;
+        auto params = graph_and_ivalues.second;
+        auto named_params = core::conversion::get_named_params(g->inputs(), params);
+
+        auto serialized_engine = core::conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
+        auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key(), serialized_engine);
+        handles.insert(method.name(), at::IValue(engine_handle));
+    }
+
+    return c10::impl::toGenericDict(handles);
+}
+
+
+c10::impl::GenericList TensorRTBackend::execute(c10::IValue handle, c10::impl::GenericList inputs) {
+    TRTORCH_ASSERT(inputs.size() > 0, "Trying to execute on empty list of arguments");
+    auto engine = handle.toCustomClass<core::runtime::TRTEngine>();
+    std::vector<at::Tensor> in_vec;
+    for (size_t i = 0, e = inputs.size(); i < e; ++i) {
+        c10::IValue val = inputs[i];
+        TRTORCH_CHECK(val.isTensor(), "TensorRT currently only accepts Tensors as inputs");
+        in_vec.push_back(val.toTensor());
+    }
+    auto outputs = core::runtime::execute_engine(in_vec, engine);
+    return c10::impl::toList(c10::List<at::Tensor>(outputs));
+}
+
+namespace {
+static auto reg = torch::jit::backend<TensorRTBackend>("tensorrt");
+}
+
+} // namespace backend
+} // namespace trtorch

Diff for: py/trtorch/csrc/tensorrt_backend.h

+19

@@ -0,0 +1,19 @@
+#pragma once
+#include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/backends/backend.h"
+
+namespace trtorch {
+namespace backend {
+
+class TensorRTBackend: public torch::jit::PyTorchBackendInterface {
+public:
+    explicit TensorRTBackend() {}
+    virtual ~TensorRTBackend() = default;
+
+    c10::IValue preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) override;
+    c10::impl::GenericDict compile(c10::IValue processed_mod, c10::impl::GenericDict method_compile_spec) override;
+    c10::impl::GenericList execute(c10::IValue handle, c10::impl::GenericList inputs) override;
+};
+
+} // namespace backend
+} // namespace trtorch
