Skip to content

Commit 17e0e8a

Browse files
committed
feat(//py)!: Porting forward the API to use kwargs
BREAKING CHANGE: This changes the API for compile settings from a dictionary of settings to a set of kwargs for the various compilation functions. This will break existing code. However, there is simple guidance for porting your code forward.

Given a dict of valid TRTorch CompileSpec settings:

```py
spec = {
    "inputs": ...
    ...
}
```

you can use this same dict with the new APIs by changing your code from:

```py
trtorch.compile(mod, spec)
```

to:

```py
trtorch.compile(mod, **spec)
```

which will unpack the dictionary as keyword arguments to the function.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 4d95b04 commit 17e0e8a

10 files changed

+208
-118
lines changed

Diff for: docsrc/py_api/trtorch.rst

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
.. _trtorch_py:
22

3+
.. automodule:: trtorch
4+
:undoc-members:
5+
36
trtorch
47
===============
58

Diff for: py/trtorch/_compile_spec.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _parse_op_precision(precision: Any) -> _types.dtype:
6464
raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " +
6565
str(precision))
6666

67-
elif isinstance(precision, _types.DataTypes):
67+
elif isinstance(precision, _types.dtype):
6868
return precision
6969

7070
else:
@@ -170,6 +170,8 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
170170
inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
171171
info.inputs = [i._to_internal() for i in inputs]
172172

173+
assert (len(info.inputs) > 0), "Require at least one input definition to compile model"
174+
173175
if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
174176
raise KeyError(
175177
"Found both key \"op_precision\", and \"enabled_precisions\" in compile spec, please port forward to using only \"enabled_precisions\""

Diff for: py/trtorch/_compiler.py

+155-93
Large diffs are not rendered by default.

Diff for: py/trtorch/csrc/tensorrt_classes.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ nvinfer1::DataType toTRTDataType(DataType value) {
3838
}
3939
}
4040

41+
Device::Device(const core::runtime::CudaDevice& internal_dev) {
42+
device_type = DeviceType::kGPU;
43+
gpu_id = internal_dev.id;
44+
dla_core = -1;
45+
allow_gpu_fallback = false;
46+
}
47+
4148
nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
4249
switch (value) {
4350
case TensorFormat::kChannelLast:

Diff for: py/trtorch/csrc/tensorrt_classes.h

+2
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct Device : torch::CustomClassHolder {
7474
allow_gpu_fallback(false) // allow_gpu_fallback
7575
{}
7676

77+
Device(const core::runtime::CudaDevice& internal_dev);
78+
7779
ADD_ENUM_GET_SET(device_type, DeviceType, static_cast<int64_t>(DeviceType::kDLA));
7880
ADD_FIELD_GET_SET(gpu_id, int64_t);
7981
ADD_FIELD_GET_SET(dla_core, int64_t);

Diff for: py/trtorch/csrc/trtorch_py.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ void set_device(const int device_id) {
103103
core::set_device(device_id);
104104
}
105105

106+
Device get_current_device() {
107+
return Device(core::runtime::get_current_device());
108+
}
109+
106110
torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info) {
107111
py::gil_scoped_acquire gil;
108112
auto trt_mod = core::CompileGraph(mod, info.toInternalCompileSpec());
@@ -315,6 +319,8 @@ PYBIND11_MODULE(_C, m) {
315319
m.def("_set_is_colored_output_on", &logging::set_is_colored_output_on, "Set if the logging output should be colored");
316320
m.def("_log", &logging::log, "Add a message to the logger");
317321
m.def("set_device", &trtorch::pyapi::set_device, "Set CUDA device id");
322+
m.def("_get_current_device", &trtorch::pyapi::get_current_device, "Get the current active CUDA device");
323+
318324

319325
py::enum_<core::util::logging::LogLevel>(m, "LogLevel", py::arithmetic())
320326
.value("INTERNAL_ERROR", core::util::logging::LogLevel::kINTERNAL_ERROR)

Diff for: tests/py/test_api.py

+25-17
Original file line numberDiff line numberDiff line change
@@ -23,44 +23,52 @@ def test_compile_traced(self):
2323
"enabled_precisions": {torch.float}
2424
}
2525

26-
trt_mod = trtorch.compile(self.traced_model, compile_spec)
26+
trt_mod = trtorch.compile(self.traced_model, **compile_spec)
2727
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
2828
self.assertTrue(same < 2e-2)
2929

3030
def test_compile_script(self):
31+
trt_mod = trtorch.compile(self.scripted_model, inputs=[self.input], device=trtorch.Device(gpu_id=0), enabled_precisions={torch.float})
32+
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
33+
self.assertTrue(same < 2e-2)
34+
35+
def test_from_torch_tensor(self):
3136
compile_spec = {
32-
"inputs": [trtorch.Input(shape=self.input.shape)],
37+
"inputs": [self.input],
3338
"device": {
3439
"device_type": trtorch.DeviceType.GPU,
3540
"gpu_id": 0,
3641
},
3742
"enabled_precisions": {torch.float}
3843
}
3944

40-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
45+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
4146
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
4247
self.assertTrue(same < 2e-2)
4348

44-
def test_from_torch_tensor(self):
49+
def test_device(self):
50+
compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
51+
52+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
53+
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
54+
self.assertTrue(same < 2e-2)
55+
56+
57+
def test_compile_script_from_dict(self):
4558
compile_spec = {
46-
"inputs": [self.input],
59+
"inputs": [trtorch.Input(shape=self.input.shape)],
4760
"device": {
4861
"device_type": trtorch.DeviceType.GPU,
4962
"gpu_id": 0,
5063
},
5164
"enabled_precisions": {torch.float}
5265
}
5366

54-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
55-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
67+
trt_mod = trtorch.compile(self.traced_model, **compile_spec)
68+
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
5669
self.assertTrue(same < 2e-2)
5770

58-
def test_device(self):
59-
compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
6071

61-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
62-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
63-
self.assertTrue(same < 2e-2)
6472

6573

6674
class TestCompileHalf(ModelTestCase):
@@ -80,7 +88,7 @@ def test_compile_script_half(self):
8088
"enabled_precisions": {torch.half}
8189
}
8290

83-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
91+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
8492
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
8593
trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
8694
self.assertTrue(same < 3e-2)
@@ -103,7 +111,7 @@ def test_compile_script_half_by_default(self):
103111
"enabled_precisions": {torch.float, torch.half}
104112
}
105113

106-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
114+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
107115
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
108116
trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
109117
self.assertTrue(same < 3e-2)
@@ -132,7 +140,7 @@ def test_compile_script(self):
132140
}
133141
}
134142

135-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
143+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
136144
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
137145
self.assertTrue(same < 2e-3)
138146

@@ -160,7 +168,7 @@ def test_compile_script(self):
160168
}
161169
}
162170

163-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
171+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
164172
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
165173
self.assertTrue(same < 2e-3)
166174

@@ -183,7 +191,7 @@ def test_pt_to_trt_to_pt(self):
183191
}
184192
}
185193

186-
trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
194+
trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", **compile_spec)
187195
trt_mod = trtorch.embed_engine_in_new_module(trt_engine, trtorch.Device("cuda:0"))
188196
same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
189197
self.assertTrue(same < 2e-3)

Diff for: tests/py/test_api_dla.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_compile_traced(self):
4040
"enabled_precisions": {torch.half}
4141
}
4242

43-
trt_mod = trtorch.compile(self.traced_model, compile_spec)
43+
trt_mod = trtorch.compile(self.traced_model, **compile_spec)
4444
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
4545
self.assertTrue(same < 2e-2)
4646

@@ -56,7 +56,7 @@ def test_compile_script(self):
5656
"enabled_precisions": {torch.half}
5757
}
5858

59-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
59+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
6060
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
6161
self.assertTrue(same < 2e-2)
6262

Diff for: tests/py/test_multi_gpu.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_compile_traced(self):
3232
}
3333
}
3434

35-
trt_mod = trtorch.compile(self.traced_model, compile_spec)
35+
trt_mod = trtorch.compile(self.traced_model, **compile_spec)
3636
trtorch.set_device(self.target_gpu)
3737
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
3838
trtorch.set_device(0)
@@ -51,7 +51,7 @@ def test_compile_script(self):
5151
}
5252
}
5353

54-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
54+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
5555
trtorch.set_device(self.target_gpu)
5656
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
5757
trtorch.set_device(0)
@@ -84,7 +84,7 @@ def test_compile_traced(self):
8484
}
8585
}
8686

87-
trt_mod = trtorch.compile(self.traced_model, compile_spec)
87+
trt_mod = trtorch.compile(self.traced_model, **compile_spec)
8888
# Changing the device ID deliberately. It should still run on correct device ID by context switching
8989
trtorch.set_device(1)
9090
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
@@ -103,7 +103,7 @@ def test_compile_script(self):
103103
}
104104
}
105105

106-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
106+
trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
107107
# Changing the device ID deliberately. It should still run on correct device ID by context switching
108108
trtorch.set_device(1)
109109
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()

Diff for: tests/py/test_ptq_dataloader_calibrator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_compile_script(self):
7272
}
7373
}
7474

75-
trt_mod = trtorch.compile(self.model, compile_spec)
75+
trt_mod = trtorch.compile(self.model, **compile_spec)
7676
int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
7777
log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
7878
acc_diff = fp32_test_acc - int8_test_acc

0 commit comments

Comments
 (0)