Skip to content

Commit 6c3e0ad

Browse files
committed
feat(//py): Allowing people using the PyTorch backend to use TRTorch/TRT
INT8 calibrators Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent fe5654f commit 6c3e0ad

File tree

6 files changed

+112
-1
lines changed

6 files changed

+112
-1
lines changed

Diff for: .gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ bazel-genfiles
55
bazel-out
66
bazel-testlogs
77
bazel-TRTorch
8+
bazel-trtorch-testing
89
third_party/pytorch
910
*.jit
1011
*.jit.pt
@@ -37,4 +38,6 @@ bdist
3738
py/trtorch/_version.py
3839
py/wheelhouse
3940
py/.eggs
40-
notebooks/.ipynb_checkpoints/
41+
notebooks/.ipynb_checkpoints/
42+
*.cache
43+
tests/py/data

Diff for: py/trtorch/_compile_spec.py

+1
Original file line numberDiff line numberDiff line change
@@ -257,5 +257,6 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
257257
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
258258
backend_spec.set_workspace_size(parsed_spec.workspace_size)
259259
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
260+
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
260261

261262
return backend_spec

Diff for: py/trtorch/csrc/register_tensorrt_classes.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ void RegisterTRTCompileSpec() {
2929
.def(torch::init<>())
3030
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
3131
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
32+
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
3233
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
3334

3435
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);

Diff for: py/trtorch/csrc/tensorrt_classes.h

+8
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,18 @@ struct CompileSpec : torch::CustomClassHolder {
9494
input_ranges.push_back(*ir);
9595
}
9696

97+
int64_t getPTQCalibratorHandle() {
98+
return (int64_t)ptq_calibrator;
99+
}
100+
97101
void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
98102
device = *d;
99103
}
100104

105+
void setPTQCalibratorViaHandle(int64_t handle) {
106+
ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
107+
}
108+
101109
ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
102110
ADD_FIELD_GET_SET(disable_tf32, bool);
103111
ADD_FIELD_GET_SET(refit, bool);

Diff for: py/trtorch/csrc/trtorch_py.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ PYBIND11_MODULE(_C, m) {
234234

235235
py::class_<CompileSpec>(m, "CompileSpec")
236236
.def(py::init<>())
237+
.def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
237238
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
238239
.def_readwrite("op_precision", &CompileSpec::op_precision)
239240
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)

Diff for: tests/py/test_ptq_to_backend.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import unittest
2+
import trtorch
3+
from trtorch.logging import *
4+
import torch
5+
import torch.nn as nn
6+
from torch.nn import functional as F
7+
import torchvision
8+
import torchvision.transforms as transforms
9+
from model_test_case import ModelTestCase
10+
11+
12+
class TestAccuracy(ModelTestCase):
13+
14+
def setUp(self):
15+
self.input = torch.randn((1, 3, 32, 32)).to("cuda")
16+
self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
17+
train=False,
18+
download=True,
19+
transform=transforms.Compose([
20+
transforms.ToTensor(),
21+
transforms.Normalize((0.4914, 0.4822, 0.4465),
22+
(0.2023, 0.1994, 0.2010))
23+
]))
24+
25+
self.testing_dataloader = torch.utils.data.DataLoader(self.testing_dataset,
26+
batch_size=1,
27+
shuffle=False,
28+
num_workers=1)
29+
self.calibrator = trtorch.ptq.DataLoaderCalibrator(self.testing_dataloader,
30+
cache_file='./calibration.cache',
31+
use_cache=False,
32+
algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
33+
device=torch.device('cuda:0'))
34+
35+
self.spec = {
36+
"forward":
37+
trtorch.TensorRTCompileSpec({
38+
"input_shapes": [[1, 3, 32, 32]],
39+
"op_precision": torch.int8,
40+
"calibrator": self.calibrator,
41+
"device": {
42+
"device_type": trtorch.DeviceType.GPU,
43+
"gpu_id": 0,
44+
"dla_core": 0,
45+
"allow_gpu_fallback": False,
46+
}
47+
})
48+
}
49+
50+
def compute_accuracy(self, testing_dataloader, model):
51+
total = 0
52+
correct = 0
53+
loss = 0.0
54+
class_probs = []
55+
class_preds = []
56+
57+
with torch.no_grad():
58+
idx = 0
59+
for data, labels in testing_dataloader:
60+
data, labels = data.cuda(), labels.cuda(non_blocking=True)
61+
out = model(data)
62+
preds = torch.max(out, 1)[1]
63+
class_probs.append([F.softmax(i, dim=0) for i in out])
64+
class_preds.append(preds)
65+
total += labels.size(0)
66+
correct += (preds == labels).sum().item()
67+
idx += 1
68+
69+
test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
70+
test_preds = torch.cat(class_preds)
71+
return correct / total
72+
73+
def test_compile_script(self):
74+
75+
fp32_test_acc = self.compute_accuracy(self.testing_dataloader, self.model)
76+
log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
77+
78+
trt_mod = torch._C._jit_to_backend("tensorrt", self.model, self.spec)
79+
int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
80+
log(Level.Info, "[TRT INT8 Backend] Test Acc: {:.2f}%".format(100 * int8_test_acc))
81+
acc_diff = fp32_test_acc - int8_test_acc
82+
self.assertTrue(abs(acc_diff) < 3)
83+
84+
85+
def test_suite():
86+
suite = unittest.TestSuite()
87+
suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16.jit.pt')))
88+
89+
return suite
90+
91+
92+
suite = test_suite()
93+
94+
runner = unittest.TextTestRunner()
95+
result = runner.run(suite)
96+
97+
exit(int(not result.wasSuccessful()))

0 commit comments

Comments
 (0)