
feat(py/trtorch/ptq): Implement INT8 Python API for PTQ #390

Merged: 20 commits, Mar 17, 2021

5 changes: 4 additions & 1 deletion .gitignore
@@ -5,6 +5,7 @@ bazel-genfiles
bazel-out
bazel-testlogs
bazel-TRTorch
bazel-trtorch-testing
third_party/pytorch
*.jit
*.jit.pt
@@ -37,4 +38,6 @@ bdist
py/trtorch/_version.py
py/wheelhouse
py/.eggs
notebooks/.ipynb_checkpoints/
notebooks/.ipynb_checkpoints/
*.cache
tests/py/data
73 changes: 69 additions & 4 deletions docs/_sources/tutorials/ptq.rst.txt
@@ -14,14 +14,17 @@ the TensorRT calibrator. With TRTorch we look to leverage existing infrastructur
calibrators easier.

LibTorch provides a ``DataLoader`` and ``Dataset`` API which streamlines preprocessing and batching of input data.
This section of the PyTorch documentation has more information https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data.
These APIs are exposed via both the C++ and Python interfaces, which makes them easier for the end user to adopt.
For the C++ interface, we use ``torch::Dataset`` and ``torch::data::make_data_loader`` objects to construct and perform pre-processing on datasets.
The equivalent functionality in the Python interface uses ``torch.utils.data.Dataset`` and ``torch.utils.data.DataLoader``.
These sections of the PyTorch documentation have more information: https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data and https://pytorch.org/tutorials/recipes/recipes/loading_data_recipe.html.
TRTorch uses DataLoaders as the base of a generic calibrator implementation, so you can reuse or quickly
implement a ``torch::Dataset`` for your target domain, place it in a DataLoader, and create an INT8 calibrator
which you can provide to TRTorch to run INT8 calibration during compilation of your module.
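
For the Python interface, a minimal custom dataset for calibration might look like the following sketch (the ``CalibrationDataset`` class and the random placeholder data are purely illustrative and not part of TRTorch):

.. code-block:: python

    import torch
    from torch.utils.data import Dataset, DataLoader

    # Hypothetical stand-in for a domain-specific dataset of pre-processed samples
    class CalibrationDataset(Dataset):

        def __init__(self, samples, labels):
            self.samples = samples
            self.labels = labels

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            return self.samples[idx], self.labels[idx]

    # Placeholder data for illustration only; use real, pre-processed samples in practice
    samples = torch.randn((100, 3, 32, 32))
    labels = torch.zeros((100,), dtype=torch.long)
    calibration_dataloader = DataLoader(CalibrationDataset(samples, labels),
                                        batch_size=1,
                                        shuffle=False)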

.. _writing_ptq:
.. _writing_ptq_cpp:

How to create your own PTQ application
How to create your own PTQ application in C++
----------------------------------------------

Here is an example interface of a ``torch::Dataset`` class for CIFAR10:
@@ -132,11 +135,73 @@ Then all thats required to setup the module for INT8 calibration is to set the f
auto trt_mod = trtorch::CompileGraph(mod, compile_spec);

If you have an existing Calibrator implementation for TensorRT, you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.

From here, not much changes in terms of how execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the TRTorch demo that takes you from training a VGG16 network on
CIFAR10 to deploying in INT8 with TRTorch here: https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq

.. _writing_ptq_python:

How to create your own PTQ application in Python
--------------------------------------------------

The TRTorch Python API provides an easy and convenient way to use PyTorch dataloaders with TensorRT calibrators. The ``DataLoaderCalibrator`` class can be used to create
a TensorRT calibrator by providing the desired configuration. The following code demonstrates how to use it:

.. code-block:: python

    testing_dataset = torchvision.datasets.CIFAR10(root='./data',
                                                   train=False,
                                                   download=True,
                                                   transform=transforms.Compose([
                                                       transforms.ToTensor(),
                                                       transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                            (0.2023, 0.1994, 0.2010))
                                                   ]))

    testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
                                                     batch_size=1,
                                                     shuffle=False,
                                                     num_workers=1)
    calibrator = trtorch.ptq.DataLoaderCalibrator(testing_dataloader,
                                                  cache_file='./calibration.cache',
                                                  use_cache=False,
                                                  algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
                                                  device=torch.device('cuda:0'))

    compile_spec = {
        "input_shapes": [[1, 3, 32, 32]],
        "op_precision": torch.int8,
        "calibrator": calibrator,
        "device": {
            "device_type": trtorch.DeviceType.GPU,
            "gpu_id": 0,
            "dla_core": 0,
            "allow_gpu_fallback": False,
            "disable_tf32": False
        }
    }
    trt_mod = trtorch.compile(model, compile_spec)
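
Once compiled, the module can be used like any other TorchScript module for inference. As noted in the C++ section above, inputs should stay in FP32; a minimal sketch, assuming a CIFAR10-sized input on the same GPU:

.. code-block:: python

    # FP32 input on the GPU used for compilation; quantization happens inside the TensorRT engine
    input_data = torch.randn((1, 3, 32, 32)).to("cuda:0")
    result = trt_mod(input_data)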

In cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
to use ``CacheCalibrator`` in INT8 mode.

.. code-block:: python

    calibrator = trtorch.ptq.CacheCalibrator("./calibration.cache")

    compile_settings = {
        "input_shapes": [[1, 3, 32, 32]],
        "op_precision": torch.int8,
        "calibrator": calibrator,
        "max_batch_size": 32,
    }

    trt_mod = trtorch.compile(model, compile_settings)

If you already have an existing calibrator class (implemented directly using the TensorRT API), you can directly set the ``calibrator`` field to an instance of your class, which can be very convenient.
For a demo of how PTQ can be performed on a VGG network using the TRTorch API, you can refer to https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_dataloader_calibrator.py
and https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_trt_calibrator.py
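
As a rough sketch only (not an API documented by this PR), such a calibrator could subclass one of the TensorRT calibrator base classes exposed by the new bindings and implement the methods the C++ trampolines look up (``get_batch_size``, ``get_batch``, ``read_calibration_cache``, ``write_calibration_cache``); here ``get_batch`` is assumed to return a list of device pointers, one per input binding:

.. code-block:: python

    import torch
    import trtorch

    class MyCalibrator(trtorch._C.IInt8EntropyCalibrator2):

        def __init__(self, dataloader, cache_file):
            super().__init__()
            self.dataloader = dataloader
            self.iterator = iter(dataloader)
            self.cache_file = cache_file
            self.current_batch = None  # keep a reference so the device memory stays alive

        def get_batch_size(self):
            return self.dataloader.batch_size

        def get_batch(self, names):
            # Return one device pointer per input binding, or None when the data is exhausted
            try:
                data, _ = next(self.iterator)
                self.current_batch = data.to("cuda:0")
                return [self.current_batch.data_ptr()]
            except StopIteration:
                return None

        def read_calibration_cache(self):
            return None  # no pre-existing cache in this sketch

        def write_calibration_cache(self, cache):
            with open(self.cache_file, "wb") as f:
                f.write(cache)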

Citations
^^^^^^^^^^^

1 change: 1 addition & 0 deletions docsrc/conf.py
@@ -12,6 +12,7 @@
#
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__name__), '../py'))

import sphinx_material
3 changes: 2 additions & 1 deletion py/BUILD
@@ -9,6 +9,7 @@ py_library(
"trtorch/__init__.py",
"trtorch/_version.py",
"trtorch/_compiler.py",
"trtorch/ptq.py",
"trtorch/_compile_spec.py",
"trtorch/_types.py",
"trtorch/logging.py"
@@ -21,4 +22,4 @@
deps = [
requirement("torch")
]
)
)
1 change: 1 addition & 0 deletions py/trtorch/__init__.py
@@ -10,6 +10,7 @@
from trtorch._version import __version__
from trtorch._compiler import *
from trtorch._compile_spec import TensorRTCompileSpec
from trtorch import ptq
from trtorch._types import *
from trtorch import logging

4 changes: 4 additions & 0 deletions py/trtorch/_compile_spec.py
@@ -135,6 +135,9 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
if "op_precision" in compile_spec:
info.op_precision = _parse_op_precision(compile_spec["op_precision"])

if "calibrator" in compile_spec:
info.ptq_calibrator = compile_spec["calibrator"]

if "disable_tf32" in compile_spec:
assert isinstance(compile_spec["disable_tf32"], bool)
info.disable_tf32 = compile_spec["disable_tf32"]
@@ -254,5 +257,6 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
backend_spec.set_workspace_size(parsed_spec.workspace_size)
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())

return backend_spec
1 change: 1 addition & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -29,6 +29,7 @@ void RegisterTRTCompileSpec() {
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
1 change: 1 addition & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
@@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
}
auto info = core::CompileSpec(internal_input_ranges);
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
info.convert_info.engine_settings.calibrator = ptq_calibrator;
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
info.convert_info.engine_settings.refit = refit;
info.convert_info.engine_settings.debug = debug;
10 changes: 10 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -94,10 +94,18 @@ struct CompileSpec : torch::CustomClassHolder {
input_ranges.push_back(*ir);
}

int64_t getPTQCalibratorHandle() {
return (int64_t)ptq_calibrator;
}

void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
device = *d;
}

void setPTQCalibratorViaHandle(int64_t handle) {
ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
}

ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
ADD_FIELD_GET_SET(disable_tf32, bool);
ADD_FIELD_GET_SET(refit, bool);
@@ -109,8 +117,10 @@
ADD_FIELD_GET_SET(workspace_size, int64_t);
ADD_FIELD_GET_SET(max_batch_size, int64_t);
ADD_FIELD_GET_SET(device, Device);
ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);

std::vector<InputRange> input_ranges;
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
DataType op_precision = DataType::kFloat;
bool disable_tf32 = false;
bool refit = false;
131 changes: 131 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -9,12 +9,96 @@
#include "torch/custom_class.h"
#include "torch/script.h"
#include "torch/torch.h"
#include "util.h"

namespace py = pybind11;

namespace trtorch {
namespace pyapi {

template <typename Derived>
class pyCalibratorTrampoline : public Derived {
public:
using Derived::Derived; // Inherit constructors

int getBatchSize() const noexcept override {
PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
}

bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
py::gil_scoped_acquire gil{};

py::function pyGetBatch = trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "get_batch");
std::vector<const char*> namesVec(names, names + nbBindings);
py::object result = pyGetBatch(namesVec);
// Copy over into the other data structure.
if (!result.is_none() && result.cast<std::vector<size_t>>().size() != 0) {
std::memcpy(bindings, result.cast<std::vector<size_t>>().data(), nbBindings * sizeof(void*));
return true;
}
return false;
}

const void* readCalibrationCache(std::size_t& length) noexcept override {
py::gil_scoped_acquire gil{};

py::function pyReadCalibrationCache =
trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "read_calibration_cache");
py::buffer cache = pyReadCalibrationCache();
if (!cache.is_none()) {
py::buffer_info info = cache.request();
length = info.size * info.itemsize;
return info.ptr;
}
return nullptr;
}

void writeCalibrationCache(const void* ptr, std::size_t length) noexcept override {
py::gil_scoped_acquire gil{};

py::function pyWriteCalibrationCache =
trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "write_calibration_cache");

py::memoryview cache{py::memoryview::from_buffer(static_cast<const uint8_t*>(ptr), {length}, {sizeof(uint8_t)})};
pyWriteCalibrationCache(cache);
}
};

class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrator> {
public:
using Derived = pyCalibratorTrampoline<nvinfer1::IInt8Calibrator>;
using Derived::Derived;

nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
PYBIND11_OVERLOAD_PURE_NAME(
nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
}
};

class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator> {
public:
using Derived = pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator>;
using Derived::Derived;

double getQuantile() const noexcept override {
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
}

double getRegressionCutoff() const noexcept override {
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
}

const void* readHistogramCache(std::size_t& length) noexcept override {
PYBIND11_OVERLOAD_PURE_NAME(
const void*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
}

void writeHistogramCache(const void* ptr, std::size_t length) noexcept override {
PYBIND11_OVERLOAD_PURE_NAME(
void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
}
};

void set_device(const int device_id) {
core::set_device(device_id);
}
@@ -102,10 +186,57 @@ PYBIND11_MODULE(_C, m) {
.value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
.value("default", EngineCapability::kDEFAULT, "Use default behavior");

py::enum_<nvinfer1::CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm")
.value("LEGACY_CALIBRATION", nvinfer1::CalibrationAlgoType::kLEGACY_CALIBRATION)
.value("ENTROPY_CALIBRATION", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION)
.value("ENTROPY_CALIBRATION_2", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION_2)
.value("MINMAX_CALIBRATION", nvinfer1::CalibrationAlgoType::kMINMAX_CALIBRATION);

py::class_<nvinfer1::IInt8Calibrator, pyIInt8Calibrator>(
m, "IInt8Calibrator", py::module_local(), "Int8 Calibrator base class")
.def(py::init_alias<>()) // Always initialize trampoline class.
.def("get_batch_size", &nvinfer1::IInt8Calibrator::getBatchSize, "Get batch size")
.def("get_algorithm", &nvinfer1::IInt8Calibrator::getAlgorithm, "Get algorithm");

py::class_<nvinfer1::IInt8LegacyCalibrator, nvinfer1::IInt8Calibrator, pyIInt8LegacyCalibrator>(
m, "IInt8LegacyCalibrator", py::module_local(), "Int8 Legacy Calibrator class")
.def(py::init_alias<>()) // Always initialize trampoline class.
.def("get_batch_size", &nvinfer1::IInt8LegacyCalibrator::getBatchSize, "Get batch size")
.def("get_algorithm", &nvinfer1::IInt8LegacyCalibrator::getAlgorithm, "Get algorithm");

py::class_<
nvinfer1::IInt8EntropyCalibrator,
nvinfer1::IInt8Calibrator,
pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator>>(
m, "IInt8EntropyCalibrator", py::module_local(), "Int8 Entropy Calibrator class")
.def(py::init_alias<>()) // Always initialize trampoline class.
.def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator::getBatchSize, "Get batch size")
.def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator::getAlgorithm, "Get algorithm");

py::class_<
nvinfer1::IInt8EntropyCalibrator2,
nvinfer1::IInt8Calibrator,
pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator2>>(
m, "IInt8EntropyCalibrator2", py::module_local(), "Int8 Entropy Calibrator2 class")
.def(py::init_alias<>()) // Always initialize trampoline class.
.def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator2::getBatchSize, "Get batch size")
.def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator2::getAlgorithm, "Get algorithm");

py::class_<
nvinfer1::IInt8MinMaxCalibrator,
nvinfer1::IInt8Calibrator,
pyCalibratorTrampoline<nvinfer1::IInt8MinMaxCalibrator>>(
m, "IInt8MinMaxCalibrator", py::module_local(), "Int8 MinMax Calibrator class")
.def(py::init_alias<>()) // Always initialize trampoline class.
.def("get_batch_size", &nvinfer1::IInt8MinMaxCalibrator::getBatchSize, "Get batch size")
.def("get_algorithm", &nvinfer1::IInt8MinMaxCalibrator::getAlgorithm, "Get algorithm");

py::class_<CompileSpec>(m, "CompileSpec")
.def(py::init<>())
.def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
.def_readwrite("op_precision", &CompileSpec::op_precision)
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
.def_readwrite("refit", &CompileSpec::refit)
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
.def_readwrite("debug", &CompileSpec::debug)
31 changes: 31 additions & 0 deletions py/trtorch/csrc/util.h
@@ -0,0 +1,31 @@
#pragma once
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <functional>
#include <iostream>
#include <string>
#include "core/util/prelude.h"

namespace trtorch {
namespace pyapi {
namespace util {

namespace py = pybind11;

// Method for calling a python function and returning the value (returned from python), used in cpp trampoline
// classes. Prints an error if no such method is overridden in python.
// T* must NOT be a trampoline class!
template <typename T>
py::function getOverload(const T* self, const std::string& overloadName) {
py::function overload = py::get_override(self, overloadName.c_str());
if (!overload) {
std::string msg{"Method: " + overloadName +
" was not overriden. Please provide an implementation for this method."};
LOG_ERROR(msg);
}
return overload;
}

} // namespace util
} // namespace pyapi
} // namespace trtorch