Commit b4b12a1

Merge pull request #390 from NVIDIA/int8_py
feat(py/trtorch/ptq): Implement INT8 Python API for PTQ

2 parents 9ff9c22 + a4e40ca

17 files changed: +802 −7 lines

.gitignore
Lines changed: 4 additions & 1 deletion

@@ -5,6 +5,7 @@ bazel-genfiles
 bazel-out
 bazel-testlogs
 bazel-TRTorch
+bazel-trtorch-testing
 third_party/pytorch
 *.jit
 *.jit.pt
@@ -37,4 +38,6 @@ bdist
 py/trtorch/_version.py
 py/wheelhouse
 py/.eggs
-notebooks/.ipynb_checkpoints/
+notebooks/.ipynb_checkpoints/
+*.cache
+tests/py/data

docs/_sources/tutorials/ptq.rst.txt
Lines changed: 69 additions & 4 deletions

@@ -14,14 +14,17 @@ the TensorRT calibrator. With TRTorch we look to leverage existing infrastructure in PyTorch to make implementing
 calibrators easier.

 LibTorch provides a ``DataLoader`` and ``Dataset`` API which streamlines preprocessing and batching input data.
-This section of the PyTorch documentation has more information https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data.
+These APIs are exposed via both the C++ and Python interfaces, which makes them easier for the end user to adopt.
+In the C++ interface, we use ``torch::Dataset`` and ``torch::data::make_data_loader`` objects to construct and perform pre-processing on datasets.
+The equivalent functionality in the Python interface uses ``torch.utils.data.Dataset`` and ``torch.utils.data.DataLoader``.
+These sections of the PyTorch documentation have more information: https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data and https://pytorch.org/tutorials/recipes/recipes/loading_data_recipe.html.
 TRTorch uses DataLoaders as the base of a generic calibrator implementation. So you will be able to reuse or quickly
 implement a ``torch::Dataset`` for your target domain, place it in a DataLoader and create an INT8 Calibrator
 which you can provide to TRTorch to run INT8 Calibration during compilation of your module.

-.. _writing_ptq:
+.. _writing_ptq_cpp:

-How to create your own PTQ application
+How to create your own PTQ application in C++
 ----------------------------------------

 Here is an example interface of a ``torch::Dataset`` class for CIFAR10:

@@ -132,11 +135,73 @@ Then all that's required to set up the module for INT8 calibration is to set the following compile settings:
 auto trt_mod = trtorch::CompileGraph(mod, compile_spec);

 If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
-
 From here not much changes in terms of how execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
 in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the TRTorch demo that takes you from training a VGG16 network on
 CIFAR10 to deploying in INT8 with TRTorch here: https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq

+.. _writing_ptq_python:
+
+How to create your own PTQ application in Python
+------------------------------------------------
+
+The TRTorch Python API provides an easy and convenient way to use PyTorch dataloaders with TensorRT calibrators. The ``DataLoaderCalibrator`` class can be used to create
+a TensorRT calibrator with the desired configuration. The following code demonstrates how to use it:
+
+.. code-block:: python
+
+    testing_dataset = torchvision.datasets.CIFAR10(root='./data',
+                                                   train=False,
+                                                   download=True,
+                                                   transform=transforms.Compose([
+                                                       transforms.ToTensor(),
+                                                       transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                                                            (0.2023, 0.1994, 0.2010))
+                                                   ]))
+
+    testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
+                                                     batch_size=1,
+                                                     shuffle=False,
+                                                     num_workers=1)
+
+    calibrator = trtorch.ptq.DataLoaderCalibrator(testing_dataloader,
+                                                  cache_file='./calibration.cache',
+                                                  use_cache=False,
+                                                  algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
+                                                  device=torch.device('cuda:0'))
+
+    compile_spec = {
+        "input_shapes": [[1, 3, 32, 32]],
+        "op_precision": torch.int8,
+        "calibrator": calibrator,
+        "device": {
+            "device_type": trtorch.DeviceType.GPU,
+            "gpu_id": 0,
+            "dla_core": 0,
+            "allow_gpu_fallback": False,
+            "disable_tf32": False
+        }
+    }
+    trt_mod = trtorch.compile(model, compile_spec)
+
+In cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
+to use ``CacheCalibrator`` in INT8 mode:
+
+.. code-block:: python
+
+    calibrator = trtorch.ptq.CacheCalibrator("./calibration.cache")
+
+    compile_settings = {
+        "input_shapes": [[1, 3, 32, 32]],
+        "op_precision": torch.int8,
+        "calibrator": calibrator,
+        "max_batch_size": 32,
+    }
+
+    trt_mod = trtorch.compile(model, compile_settings)
+
+If you already have an existing calibrator class (implemented directly using the TensorRT API), you can directly set the ``calibrator`` field to your class, which can be very convenient.
+For a demo of how PTQ can be performed on a VGG network using the TRTorch API, you can refer to https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_dataloader_calibrator.py
+and https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_trt_calibrator.py

 Citations
 ^^^^^^^^^^^
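From here, inference with the compiled INT8 module is unchanged from FP32, as the documentation above notes. A minimal sketch, assuming the ``trt_mod`` and ``testing_dataloader`` objects from the example above (an illustration only, not code from this commit):

    # Inputs stay in FP32; TensorRT applies the INT8 quantization internally.
    for data, labels in testing_dataloader:
        out = trt_mod(data.cuda())      # same call signature as the original module
        preds = out.argmax(dim=1)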

docsrc/conf.py
Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
 #
 import os
 import sys
+
 sys.path.append(os.path.join(os.path.dirname(__name__), '../py'))

 import sphinx_material

py/BUILD
Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@ py_library(
         "trtorch/__init__.py",
         "trtorch/_version.py",
         "trtorch/_compiler.py",
+        "trtorch/ptq.py",
        "trtorch/_compile_spec.py",
         "trtorch/_types.py",
         "trtorch/logging.py"
@@ -21,4 +22,4 @@ py_library(
     deps = [
         requirement("torch")
     ]
-)
+)

py/trtorch/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 from trtorch._version import __version__
 from trtorch._compiler import *
 from trtorch._compile_spec import TensorRTCompileSpec
+from trtorch import ptq
 from trtorch._types import *
 from trtorch import logging

py/trtorch/_compile_spec.py
Lines changed: 4 additions & 0 deletions

@@ -135,6 +135,9 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     if "op_precision" in compile_spec:
         info.op_precision = _parse_op_precision(compile_spec["op_precision"])

+    if "calibrator" in compile_spec:
+        info.ptq_calibrator = compile_spec["calibrator"]
+
     if "disable_tf32" in compile_spec:
         assert isinstance(compile_spec["disable_tf32"], bool)
         info.disable_tf32 = compile_spec["disable_tf32"]

@@ -254,5 +257,6 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
     backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
     backend_spec.set_workspace_size(parsed_spec.workspace_size)
     backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
+    backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())

     return backend_spec
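Note the pattern in ``TensorRTCompileSpec`` above: the TorchScript-registered backend class cannot hold a pybind11-bound pointer directly, so the calibrator crosses that boundary as an opaque integer handle. A hypothetical sketch of the round trip (``calibrator`` stands for any object built via ``trtorch.ptq``, and ``backend_spec`` for the ``torch.classes.tensorrt.CompileSpec`` object constructed inside the function):

    # Sketch only: mirrors the flow in _parse_compile_spec / TensorRTCompileSpec.
    parsed_spec = trtorch._C.CompileSpec()
    parsed_spec.ptq_calibrator = calibrator           # stored as a raw nvinfer1::IInt8Calibrator*
    handle = parsed_spec._get_calibrator_handle()     # C++ returns (int64_t)ptq_calibrator
    backend_spec._set_ptq_calibrator(handle)          # C++ casts the int64 back to a pointer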

py/trtorch/csrc/register_tensorrt_classes.cpp
Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ void RegisterTRTCompileSpec() {
       .def(torch::init<>())
       .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
       .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
+      .def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
       .def("__str__", &trtorch::pyapi::CompileSpec::stringify);

   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);

py/trtorch/csrc/tensorrt_classes.cpp
Lines changed: 1 addition & 0 deletions

@@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   }
   auto info = core::CompileSpec(internal_input_ranges);
   info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
+  info.convert_info.engine_settings.calibrator = ptq_calibrator;
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;
   info.convert_info.engine_settings.debug = debug;

py/trtorch/csrc/tensorrt_classes.h
Lines changed: 10 additions & 0 deletions

@@ -94,10 +94,18 @@ struct CompileSpec : torch::CustomClassHolder {
     input_ranges.push_back(*ir);
   }

+  int64_t getPTQCalibratorHandle() {
+    return (int64_t)ptq_calibrator;
+  }
+
   void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
     device = *d;
   }

+  void setPTQCalibratorViaHandle(int64_t handle) {
+    ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
+  }
+
   ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
   ADD_FIELD_GET_SET(disable_tf32, bool);
   ADD_FIELD_GET_SET(refit, bool);

@@ -109,8 +117,10 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(workspace_size, int64_t);
   ADD_FIELD_GET_SET(max_batch_size, int64_t);
   ADD_FIELD_GET_SET(device, Device);
+  ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);

   std::vector<InputRange> input_ranges;
+  nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
   DataType op_precision = DataType::kFloat;
   bool disable_tf32 = false;
   bool refit = false;

py/trtorch/csrc/trtorch_py.cpp
Lines changed: 131 additions & 0 deletions

@@ -9,12 +9,96 @@
 #include "torch/custom_class.h"
 #include "torch/script.h"
 #include "torch/torch.h"
+#include "util.h"

 namespace py = pybind11;

 namespace trtorch {
 namespace pyapi {

+template <typename Derived>
+class pyCalibratorTrampoline : public Derived {
+ public:
+  using Derived::Derived; // Inherit constructors
+
+  int getBatchSize() const noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
+  }
+
+  bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
+    py::gil_scoped_acquire gil{};
+
+    py::function pyGetBatch = trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "get_batch");
+    std::vector<const char*> namesVec(names, names + nbBindings);
+    py::object result = pyGetBatch(namesVec);
+    // Copy over into the other data structure.
+    if (!result.is_none() && result.cast<std::vector<size_t>>().size() != 0) {
+      std::memcpy(bindings, result.cast<std::vector<size_t>>().data(), nbBindings * sizeof(void*));
+      return true;
+    }
+    return false;
+  }
+
+  const void* readCalibrationCache(std::size_t& length) noexcept override {
+    py::gil_scoped_acquire gil{};
+
+    py::function pyReadCalibrationCache =
+        trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "read_calibration_cache");
+    py::buffer cache = pyReadCalibrationCache();
+    if (!cache.is_none()) {
+      py::buffer_info info = cache.request();
+      length = info.size * info.itemsize;
+      return info.ptr;
+    }
+    return nullptr;
+  }
+
+  void writeCalibrationCache(const void* ptr, std::size_t length) noexcept override {
+    py::gil_scoped_acquire gil{};
+
+    py::function pyWriteCalibrationCache =
+        trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "write_calibration_cache");
+
+    py::memoryview cache{py::memoryview::from_buffer(static_cast<const uint8_t*>(ptr), {length}, {sizeof(uint8_t)})};
+    pyWriteCalibrationCache(cache);
+  }
+};
+
+class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrator> {
+ public:
+  using Derived = pyCalibratorTrampoline<nvinfer1::IInt8Calibrator>;
+  using Derived::Derived;
+
+  nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(
+        nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
+  }
+};
+
+class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator> {
+ public:
+  using Derived = pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator>;
+  using Derived::Derived;
+
+  double getQuantile() const noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
+  }
+
+  double getRegressionCutoff() const noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
+  }
+
+  const void* readHistogramCache(std::size_t& length) noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(
+        const void*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
+  }
+
+  void writeHistogramCache(const void* ptr, std::size_t length) noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(
+        void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
+  }
+};
+
 void set_device(const int device_id) {
   core::set_device(device_id);
 }

@@ -102,10 +186,57 @@ PYBIND11_MODULE(_C, m) {
       .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
       .value("default", EngineCapability::kDEFAULT, "Use default behavior");

+  py::enum_<nvinfer1::CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm")
+      .value("LEGACY_CALIBRATION", nvinfer1::CalibrationAlgoType::kLEGACY_CALIBRATION)
+      .value("ENTROPY_CALIBRATION", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION)
+      .value("ENTROPY_CALIBRATION_2", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION_2)
+      .value("MINMAX_CALIBRATION", nvinfer1::CalibrationAlgoType::kMINMAX_CALIBRATION);
+
+  py::class_<nvinfer1::IInt8Calibrator, pyIInt8Calibrator>(
+      m, "IInt8Calibrator", py::module_local(), "Int8 Calibrator base class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8Calibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8Calibrator::getAlgorithm, "Get algorithm");
+
+  py::class_<nvinfer1::IInt8LegacyCalibrator, nvinfer1::IInt8Calibrator, pyIInt8LegacyCalibrator>(
+      m, "IInt8LegacyCalibrator", py::module_local(), "Int8 Legacy Calibrator class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8LegacyCalibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8LegacyCalibrator::getAlgorithm, "Get algorithm");
+
+  py::class_<
+      nvinfer1::IInt8EntropyCalibrator,
+      nvinfer1::IInt8Calibrator,
+      pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator>>(
+      m, "IInt8EntropyCalibrator", py::module_local(), "Int8 Entropy Calibrator class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator::getAlgorithm, "Get algorithm");
+
+  py::class_<
+      nvinfer1::IInt8EntropyCalibrator2,
+      nvinfer1::IInt8Calibrator,
+      pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator2>>(
+      m, "IInt8EntropyCalibrator2", py::module_local(), "Int8 Entropy Calibrator2 class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator2::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator2::getAlgorithm, "Get algorithm");
+
+  py::class_<
+      nvinfer1::IInt8MinMaxCalibrator,
+      nvinfer1::IInt8Calibrator,
+      pyCalibratorTrampoline<nvinfer1::IInt8MinMaxCalibrator>>(
+      m, "IInt8MinMaxCalibrator", py::module_local(), "Int8 MinMax Calibrator class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8MinMaxCalibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8MinMaxCalibrator::getAlgorithm, "Get algorithm");
+
   py::class_<CompileSpec>(m, "CompileSpec")
       .def(py::init<>())
+      .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
       .def_readwrite("input_ranges", &CompileSpec::input_ranges)
       .def_readwrite("op_precision", &CompileSpec::op_precision)
+      .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
       .def_readwrite("refit", &CompileSpec::refit)
       .def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
       .def_readwrite("debug", &CompileSpec::debug)

py/trtorch/csrc/util.h
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+#pragma once
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <functional>
+#include <iostream>
+#include <string>
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace pyapi {
+namespace util {
+
+namespace py = pybind11;
+
+// Method for calling the python function and returning the value (returned from python) used in cpp trampoline
+// classes. Prints an error if no such method is overridden in python.
+// T* must NOT be a trampoline class!
+template <typename T>
+py::function getOverload(const T* self, const std::string& overloadName) {
+  py::function overload = py::get_override(self, overloadName.c_str());
+  if (!overload) {
+    std::string msg{"Method: " + overloadName +
+                    " was not overridden. Please provide an implementation for this method."};
+    LOG_ERROR(msg);
+  }
+  return overload;
+}
+
+} // namespace util
+} // namespace pyapi
+} // namespace trtorch
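One practical consequence of ``getOverload`` for Python subclass authors: a missing override surfaces through TRTorch's logger rather than as a Python exception. A hypothetical example of the failure mode:

    # This subclass omits get_batch; during calibration, getOverload("get_batch")
    # finds no Python override and logs "Method: get_batch was not overridden...".
    class IncompleteCalibrator(trtorch._C.IInt8EntropyCalibrator2):
        def get_batch_size(self):
            return 1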
