Skip to content

Commit 6eeba1c

Browse files
committed
feat(//py): [to_backend] adding device specification support for
to_backend Also fixes nested dictionary bug reported in #286 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 72bf74b commit 6eeba1c

File tree

6 files changed

+52
-37
lines changed

6 files changed

+52
-37
lines changed

Diff for: py/trtorch/_compile_spec.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
147147
assert isinstance(compile_spec["strict_types"], bool)
148148
info.strict_types = compile_spec["strict_types"]
149149

150-
if "allow_gpu_fallback" in compile_spec:
151-
assert isinstance(compile_spec["allow_gpu_fallback"], bool)
152-
info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]
153-
154150
if "device" in compile_spec:
155151
info.device = _parse_device(compile_spec["device"])
156152

@@ -177,7 +173,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
177173
return info
178174

179175

180-
def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
176+
def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
181177
"""
182178
Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
183179
@@ -235,14 +231,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
235231
ir.set_max(i.max)
236232
backend_spec.append_input_range(ir)
237233

238-
for i in parsed_spec.device:
239-
ir = torch.classes.tensorrt.Device()
240-
ir.set_device_type(i.device_type)
241-
ir.set_gpu_id(i.gpu_id)
242-
ir.set_dla_core(i.dla_core)
243-
ir.set_allow_gpu_fallback(i.allow_gpu_fallback)
244-
backend_spec.set_device(ir)
234+
d = torch.classes.tensorrt.Device()
235+
d.set_device_type(int(parsed_spec.device.device_type))
236+
d.set_gpu_id(parsed_spec.device.gpu_id)
237+
d.set_dla_core(parsed_spec.device.dla_core)
238+
d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
245239

240+
backend_spec.set_device(d)
246241
backend_spec.set_op_precision(int(parsed_spec.op_precision))
247242
backend_spec.set_refit(parsed_spec.refit)
248243
backend_spec.set_debug(parsed_spec.debug)

Diff for: py/trtorch/csrc/register_tensorrt_classes.cpp

+18-6
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,34 @@
33
namespace trtorch {
44
namespace backend {
55
namespace {
6-
void RegisterTRTCompileSpec() {
6+
77
#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
88
(registry).def("set_" #field_name, &class_name::set_##field_name); \
99
(registry).def("get_" #field_name, &class_name::get_##field_name);
1010

11+
void RegisterTRTCompileSpec() {
1112
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
12-
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
13+
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
1314

1415
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
1516
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
1617
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
1718

19+
static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
20+
torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());
21+
22+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
23+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
24+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
25+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
26+
27+
1828
static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
19-
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
20-
.def(torch::init<>())
21-
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
22-
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
29+
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
30+
.def(torch::init<>())
31+
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
32+
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
33+
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
2334

2435
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
2536
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
@@ -30,6 +41,7 @@ void RegisterTRTCompileSpec() {
3041
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
3142
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
3243
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
44+
3345
}
3446

3547
struct TRTTSRegistrations {

Diff for: py/trtorch/csrc/tensorrt_backend.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::
4646
auto method = mod.get_method(method_name);
4747
auto g = method.graph();
4848

49-
auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
49+
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
5050
LOG_DEBUG(raw_spec->stringify());
5151
auto cfg = raw_spec->toInternalCompileSpec();
5252
auto convert_cfg = std::move(cfg.convert_info);

Diff for: py/trtorch/csrc/tensorrt_classes.h

+17-13
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ namespace pyapi {
1717
return field_name; \
1818
}
1919

20+
// TODO: Make this error message more informative
21+
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
22+
void set_##field_name(int64_t val) { \
23+
TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
24+
field_name = static_cast<type>(val); \
25+
} \
26+
int64_t get_##field_name() { \
27+
return static_cast<int64_t>(field_name); \
28+
}
29+
2030
struct InputRange : torch::CustomClassHolder {
2131
std::vector<int64_t> min;
2232
std::vector<int64_t> opt;
@@ -59,7 +69,7 @@ struct Device : torch::CustomClassHolder {
5969
allow_gpu_fallback(false) // allow_gpu_fallback
6070
{}
6171

62-
ADD_FIELD_GET_SET(device_type, DeviceType);
72+
ADD_ENUM_GET_SET(device_type, DeviceType, 1);
6373
ADD_FIELD_GET_SET(gpu_id, int64_t);
6474
ADD_FIELD_GET_SET(dla_core, int64_t);
6575
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
@@ -77,28 +87,22 @@ enum class EngineCapability : int8_t {
7787
std::string to_str(EngineCapability value);
7888
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);
7989

80-
// TODO: Make this error message more informative
81-
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
82-
void set_##field_name(int64_t val) { \
83-
TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
84-
field_name = static_cast<type>(val); \
85-
} \
86-
int64_t get_##field_name() { \
87-
return static_cast<int64_t>(field_name); \
88-
}
89-
9090
struct CompileSpec : torch::CustomClassHolder {
9191
core::CompileSpec toInternalCompileSpec();
9292
std::string stringify();
9393
void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
9494
input_ranges.push_back(*ir);
9595
}
9696

97-
ADD_ENUM_GET_SET(op_precision, DataType, 3);
97+
void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
98+
device = *d;
99+
}
100+
101+
ADD_ENUM_GET_SET(op_precision, DataType, 2);
98102
ADD_FIELD_GET_SET(refit, bool);
99103
ADD_FIELD_GET_SET(debug, bool);
100104
ADD_FIELD_GET_SET(strict_types, bool);
101-
ADD_ENUM_GET_SET(capability, EngineCapability, 3);
105+
ADD_ENUM_GET_SET(capability, EngineCapability, 2);
102106
ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
103107
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
104108
ADD_FIELD_GET_SET(workspace_size, int64_t);

Diff for: tests/py/BUILD

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ py_test(
1717
] + select({
1818
":aarch64_linux": [
1919
"test_api_dla.py"
20-
]
20+
],
21+
"//conditions:default" : []
2122
}),
2223
deps = [
2324
requirement("torchvision")

Diff for: tests/py/test_to_backend_api.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ def setUp(self):
1919
"refit": False,
2020
"debug": False,
2121
"strict_types": False,
22-
"allow_gpu_fallback": True,
23-
"device_type": "gpu",
22+
"device": {
23+
"device_type": trtorch.DeviceType.GPU,
24+
"gpu_id": 0,
25+
"allow_gpu_fallback": True
26+
},
2427
"capability": trtorch.EngineCapability.default,
2528
"num_min_timing_iters": 2,
2629
"num_avg_timing_iters": 1,
@@ -29,14 +32,14 @@ def setUp(self):
2932
}
3033

3134
def test_to_backend_lowering(self):
32-
trt_mod = torch._C._jit_to_tensorrt(self.scripted_model._c, {"forward": self.spec})
35+
trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec)
3336
same = (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max()
3437
self.assertTrue(same < 2e-3)
3538

3639

3740
def test_suite():
3841
suite = unittest.TestSuite()
39-
suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.mobilenet_v2(pretrained=True)))
42+
suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.resnet18(pretrained=True)))
4043

4144
return suite
4245

0 commit comments

Comments (0)