|
9 | 9 | #include "torch/custom_class.h"
|
10 | 10 | #include "torch/script.h"
|
11 | 11 | #include "torch/torch.h"
|
| 12 | +#include "util.h" |
12 | 13 |
|
13 | 14 | namespace py = pybind11;
|
14 | 15 |
|
15 | 16 | namespace trtorch {
|
16 | 17 | namespace pyapi {
|
17 | 18 |
|
| 19 | +template <typename Derived> |
| 20 | +class pyCalibratorTrampoline : public Derived { |
| 21 | + public: |
| 22 | + using Derived::Derived; // Inherit constructors |
| 23 | + |
| 24 | + int getBatchSize() const noexcept override { |
| 25 | + PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize); |
| 26 | + } |
| 27 | + |
| 28 | + bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { |
| 29 | + py::gil_scoped_acquire gil{}; |
| 30 | + |
| 31 | + py::function pyGetBatch = trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "get_batch"); |
| 32 | + std::vector<const char*> namesVec(names, names + nbBindings); |
| 33 | + py::object result = pyGetBatch(namesVec); |
| 34 | + // Copy over into the other data structure. |
| 35 | + if (!result.is_none() && result.cast<std::vector<size_t>>().size() != 0) { |
| 36 | + std::memcpy(bindings, result.cast<std::vector<size_t>>().data(), nbBindings * sizeof(void*)); |
| 37 | + return true; |
| 38 | + } |
| 39 | + return false; |
| 40 | + } |
| 41 | + |
| 42 | + const void* readCalibrationCache(std::size_t& length) noexcept override { |
| 43 | + py::gil_scoped_acquire gil{}; |
| 44 | + |
| 45 | + py::function pyReadCalibrationCache = |
| 46 | + trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "read_calibration_cache"); |
| 47 | + py::buffer cache = pyReadCalibrationCache(); |
| 48 | + if (!cache.is_none()) { |
| 49 | + py::buffer_info info = cache.request(); |
| 50 | + length = info.size * info.itemsize; |
| 51 | + return info.ptr; |
| 52 | + } |
| 53 | + return nullptr; |
| 54 | + } |
| 55 | + |
| 56 | + void writeCalibrationCache(const void* ptr, std::size_t length) noexcept override { |
| 57 | + py::gil_scoped_acquire gil{}; |
| 58 | + |
| 59 | + py::function pyWriteCalibrationCache = |
| 60 | + trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "write_calibration_cache"); |
| 61 | + |
| 62 | + py::memoryview cache{py::memoryview::from_buffer(static_cast<const uint8_t*>(ptr), {length}, {sizeof(uint8_t)})}; |
| 63 | + pyWriteCalibrationCache(cache); |
| 64 | + } |
| 65 | +}; |
| 66 | + |
| 67 | +class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrator> { |
| 68 | + public: |
| 69 | + using Derived = pyCalibratorTrampoline<nvinfer1::IInt8Calibrator>; |
| 70 | + using Derived::Derived; |
| 71 | + |
| 72 | + nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override { |
| 73 | + PYBIND11_OVERLOAD_PURE_NAME( |
| 74 | + nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm); |
| 75 | + } |
| 76 | +}; |
| 77 | + |
| 78 | +class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator> { |
| 79 | + public: |
| 80 | + using Derived = pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator>; |
| 81 | + using Derived::Derived; |
| 82 | + |
| 83 | + double getQuantile() const noexcept override { |
| 84 | + PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile); |
| 85 | + } |
| 86 | + |
| 87 | + double getRegressionCutoff() const noexcept override { |
| 88 | + PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff); |
| 89 | + } |
| 90 | + |
| 91 | + const void* readHistogramCache(std::size_t& length) noexcept override { |
| 92 | + PYBIND11_OVERLOAD_PURE_NAME( |
| 93 | + const void*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length); |
| 94 | + } |
| 95 | + |
| 96 | + void writeHistogramCache(const void* ptr, std::size_t length) noexcept override { |
| 97 | + PYBIND11_OVERLOAD_PURE_NAME( |
| 98 | + void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length); |
| 99 | + } |
| 100 | +}; |
| 101 | + |
18 | 102 | void set_device(const int device_id) {
|
19 | 103 | core::set_device(device_id);
|
20 | 104 | }
|
@@ -102,10 +186,57 @@ PYBIND11_MODULE(_C, m) {
|
102 | 186 | .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
|
103 | 187 | .value("default", EngineCapability::kDEFAULT, "Use default behavior");
|
104 | 188 |
|
| 189 | + py::enum_<nvinfer1::CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm") |
| 190 | + .value("LEGACY_CALIBRATION", nvinfer1::CalibrationAlgoType::kLEGACY_CALIBRATION) |
| 191 | + .value("ENTROPY_CALIBRATION", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION) |
| 192 | + .value("ENTROPY_CALIBRATION_2", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION_2) |
| 193 | + .value("MINMAX_CALIBRATION", nvinfer1::CalibrationAlgoType::kMINMAX_CALIBRATION); |
| 194 | + |
| 195 | + py::class_<nvinfer1::IInt8Calibrator, pyIInt8Calibrator>( |
| 196 | + m, "IInt8Calibrator", py::module_local(), "Int8 Calibrator base class") |
| 197 | + .def(py::init_alias<>()) // Always initialize trampoline class. |
| 198 | + .def("get_batch_size", &nvinfer1::IInt8Calibrator::getBatchSize, "Get batch size") |
| 199 | + .def("get_algorithm", &nvinfer1::IInt8Calibrator::getAlgorithm, "Get algorithm"); |
| 200 | + |
| 201 | + py::class_<nvinfer1::IInt8LegacyCalibrator, nvinfer1::IInt8Calibrator, pyIInt8LegacyCalibrator>( |
| 202 | + m, "IInt8LegacyCalibrator", py::module_local(), "Int8 Legacy Calibrator class") |
| 203 | + .def(py::init_alias<>()) // Always initialize trampoline class. |
| 204 | + .def("get_batch_size", &nvinfer1::IInt8LegacyCalibrator::getBatchSize, "Get batch size") |
| 205 | + .def("get_algorithm", &nvinfer1::IInt8LegacyCalibrator::getAlgorithm, "Get algorithm"); |
| 206 | + |
| 207 | + py::class_< |
| 208 | + nvinfer1::IInt8EntropyCalibrator, |
| 209 | + nvinfer1::IInt8Calibrator, |
| 210 | + pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator>>( |
| 211 | + m, "IInt8EntropyCalibrator", py::module_local(), "Int8 Entropy Calibrator class") |
| 212 | + .def(py::init_alias<>()) // Always initialize trampoline class. |
| 213 | + .def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator::getBatchSize, "Get batch size") |
| 214 | + .def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator::getAlgorithm, "Get algorithm"); |
| 215 | + |
| 216 | + py::class_< |
| 217 | + nvinfer1::IInt8EntropyCalibrator2, |
| 218 | + nvinfer1::IInt8Calibrator, |
| 219 | + pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator2>>( |
| 220 | + m, "IInt8EntropyCalibrator2", py::module_local(), "Int8 Entropy Calibrator2 class") |
| 221 | + .def(py::init_alias<>()) // Always initialize trampoline class. |
| 222 | + .def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator2::getBatchSize, "Get batch size") |
| 223 | + .def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator2::getAlgorithm, "Get algorithm"); |
| 224 | + |
| 225 | + py::class_< |
| 226 | + nvinfer1::IInt8MinMaxCalibrator, |
| 227 | + nvinfer1::IInt8Calibrator, |
| 228 | + pyCalibratorTrampoline<nvinfer1::IInt8MinMaxCalibrator>>( |
| 229 | + m, "IInt8MinMaxCalibrator", py::module_local(), "Int8 MinMax Calibrator class") |
| 230 | + .def(py::init_alias<>()) // Always initialize trampoline class. |
| 231 | + .def("get_batch_size", &nvinfer1::IInt8MinMaxCalibrator::getBatchSize, "Get batch size") |
| 232 | + .def("get_algorithm", &nvinfer1::IInt8MinMaxCalibrator::getAlgorithm, "Get algorithm"); |
| 233 | + |
105 | 234 | py::class_<CompileSpec>(m, "CompileSpec")
|
106 | 235 | .def(py::init<>())
|
| 236 | + .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator") |
107 | 237 | .def_readwrite("input_ranges", &CompileSpec::input_ranges)
|
108 | 238 | .def_readwrite("op_precision", &CompileSpec::op_precision)
|
| 239 | + .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator) |
109 | 240 | .def_readwrite("refit", &CompileSpec::refit)
|
110 | 241 | .def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
|
111 | 242 | .def_readwrite("debug", &CompileSpec::debug)
|
|
0 commit comments