
feat(py/trtorch/ptq): Implement INT8 Python API for PTQ #390

Merged · 20 commits · Mar 17, 2021
Changes from 9 commits
3 changes: 2 additions & 1 deletion py/BUILD
@@ -9,6 +9,7 @@ py_library(
        "trtorch/__init__.py",
        "trtorch/_version.py",
        "trtorch/_compiler.py",
        "trtorch/ptq.py",
        "trtorch/_compile_spec.py",
        "trtorch/_types.py",
        "trtorch/logging.py"
@@ -21,4 +22,4 @@ py_library(
    deps = [
        requirement("torch")
    ]
)
)
1 change: 1 addition & 0 deletions py/trtorch/__init__.py
@@ -10,6 +10,7 @@
from trtorch._version import __version__
from trtorch._compiler import *
from trtorch._compile_spec import TensorRTCompileSpec
from trtorch import ptq
from trtorch._types import *
from trtorch import logging

3 changes: 3 additions & 0 deletions py/trtorch/_compile_spec.py
@@ -135,6 +135,9 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
    if "op_precision" in compile_spec:
        info.op_precision = _parse_op_precision(compile_spec["op_precision"])

    if "calibrator" in compile_spec:
        info.ptq_calibrator = compile_spec["calibrator"]

    if "disable_tf32" in compile_spec:
        assert isinstance(compile_spec["disable_tf32"], bool)
        info.disable_tf32 = compile_spec["disable_tf32"]
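For reference, a sketch of how this new key is meant to be supplied from user code; the spec values and model below are illustrative stand-ins, and only the "calibrator" entry is what the branch above parses:

import torch
import trtorch

# Sketch only: `script_model` is an existing torch.jit.ScriptModule and
# `calibrator` is an object built by trtorch.ptq (see ptq.py below).
compile_spec = {
    "input_shapes": [[1, 3, 32, 32]],  # illustrative
    "op_precision": torch.int8,  # INT8 kernels require a calibrator
    "calibrator": calibrator  # consumed by the branch added above
}
trt_mod = trtorch.compile(script_model, compile_spec)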
1 change: 1 addition & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
@@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
  }
  auto info = core::CompileSpec(internal_input_ranges);
  info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
  info.convert_info.engine_settings.calibrator = ptq_calibrator;
  info.convert_info.engine_settings.disable_tf32 = disable_tf32;
  info.convert_info.engine_settings.refit = refit;
  info.convert_info.engine_settings.debug = debug;
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -109,8 +109,10 @@ struct CompileSpec : torch::CustomClassHolder {
  ADD_FIELD_GET_SET(workspace_size, int64_t);
  ADD_FIELD_GET_SET(max_batch_size, int64_t);
  ADD_FIELD_GET_SET(device, Device);
  ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);

  std::vector<InputRange> input_ranges;
  nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
  DataType op_precision = DataType::kFloat;
  bool disable_tf32 = false;
  bool refit = false;
119 changes: 119 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -9,12 +9,95 @@
#include "torch/custom_class.h"
#include "torch/script.h"
#include "torch/torch.h"
#include "util.h"

using namespace nvinfer1;
using namespace pybind11::literals;
namespace py = pybind11;

namespace trtorch {
namespace pyapi {

template <typename Derived>
class pyCalibratorTrampoline : public Derived {
 public:
  using Derived::Derived; // Inherit constructors

  int getBatchSize() const noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
  }

  bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
    py::gil_scoped_acquire gil{};

    py::function pyGetBatch = trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "get_batch");
    std::vector<const char*> namesVec(names, names + nbBindings);
    py::object result = pyGetBatch(namesVec);
    // Copy over into the other data structure.
    if (!result.is_none() && result.cast<std::vector<size_t>>().size() != 0) {
      std::memcpy(bindings, result.cast<std::vector<size_t>>().data(), nbBindings * sizeof(void*));
      return true;
    }
    return false;
  }

  const void* readCalibrationCache(std::size_t& length) noexcept override {
    py::gil_scoped_acquire gil{};

    py::function pyReadCalibrationCache =
        trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "read_calibration_cache");
    py::buffer cache = pyReadCalibrationCache();
    if (!cache.is_none()) {
      py::buffer_info info = cache.request();
      length = info.size * info.itemsize;
      return info.ptr;
    }
    return nullptr;
  }

  void writeCalibrationCache(const void* ptr, std::size_t length) noexcept override {
    py::gil_scoped_acquire gil{};

    py::function pyWriteCalibrationCache =
        trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "write_calibration_cache");

    py::memoryview cache{py::memoryview::from_buffer(static_cast<const uint8_t*>(ptr), {length}, {sizeof(uint8_t)})};
    pyWriteCalibrationCache(cache);
  }
};
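Each virtual call above is forwarded to a same-named Python method, so a hand-written Python calibrator only needs to provide those methods. A minimal sketch, assuming batches of CUDA tensors prepared by the caller (this class is hypothetical, not part of the diff):

import trtorch

class MyCalibrator(trtorch._C.IInt8EntropyCalibrator2):

    def __init__(self, batches):
        super().__init__()
        self.batches = iter(batches)  # assumed: an iterable of CUDA tensors

    def get_batch_size(self):
        return 1

    def get_batch(self, names):
        # Return device pointers; getBatch() above memcpys them into `bindings`.
        try:
            return [next(self.batches).data_ptr()]
        except StopIteration:
            return None  # translated by the trampoline into `return false`

    def read_calibration_cache(self):
        return None  # no cache; calibrate from scratch

    def write_calibration_cache(self, cache):
        pass  # discard the cache in this sketch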

class pyIInt8Calibrator : public pyCalibratorTrampoline<IInt8Calibrator> {
 public:
  using Derived = pyCalibratorTrampoline<IInt8Calibrator>;
  using Derived::Derived;

  CalibrationAlgoType getAlgorithm() noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(CalibrationAlgoType, IInt8Calibrator, "get_algorithm", getAlgorithm);
  }
};

class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<IInt8LegacyCalibrator> {
Collaborator: Do we need to support this at all? I think for the DataLoader calibrator we should only support current features. Is there a deprecation plan for the Legacy Calibrator?

Collaborator (Author): I checked with TRT and this will stay for at least the next release; I haven't seen a deprecation plan yet. I think there's no harm in supporting it, and calibration still worked when I ran it through this path.

 public:
  using Derived = pyCalibratorTrampoline<IInt8LegacyCalibrator>;
  using Derived::Derived;

  double getQuantile() const noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(double, IInt8LegacyCalibrator, "get_quantile", getQuantile);
  }

  double getRegressionCutoff() const noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(double, IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
  }

  const void* readHistogramCache(std::size_t& length) noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(const void*, IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
  }

  void writeHistogramCache(const void* ptr, std::size_t length) noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(void, IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
  }
};

void set_device(const int device_id) {
  core::set_device(device_id);
}
@@ -102,10 +185,46 @@ PYBIND11_MODULE(_C, m) {
      .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
      .value("default", EngineCapability::kDEFAULT, "Use default behavior");

  py::enum_<CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm")
      .value("LEGACY_CALIBRATION", CalibrationAlgoType::kLEGACY_CALIBRATION)
      .value("ENTROPY_CALIBRATION", CalibrationAlgoType::kENTROPY_CALIBRATION)
      .value("ENTROPY_CALIBRATION_2", CalibrationAlgoType::kENTROPY_CALIBRATION_2)
      .value("MINMAX_CALIBRATION", CalibrationAlgoType::kMINMAX_CALIBRATION);

Collaborator (on LEGACY_CALIBRATION): same question about legacy support.

  py::class_<IInt8Calibrator, pyIInt8Calibrator>(m, "IInt8Calibrator", py::module_local(), "Int8 Calibrator base class")
      .def(py::init_alias<>()) // Always initialize trampoline class.
      .def("get_batch_size", &IInt8Calibrator::getBatchSize, "Get batch size")
      .def("get_algorithm", &IInt8Calibrator::getAlgorithm, "Get algorithm");

  py::class_<IInt8LegacyCalibrator, IInt8Calibrator, pyIInt8LegacyCalibrator>(
      m, "IInt8LegacyCalibrator", py::module_local(), "Int8 Legacy Calibrator class")
      .def(py::init_alias<>()) // Always initialize trampoline class.
      .def("get_batch_size", &IInt8LegacyCalibrator::getBatchSize, "Get batch size")
      .def("get_algorithm", &IInt8LegacyCalibrator::getAlgorithm, "Get algorithm");

Collaborator (on IInt8LegacyCalibrator): Same again.

  py::class_<IInt8EntropyCalibrator, IInt8Calibrator, pyCalibratorTrampoline<IInt8EntropyCalibrator>>(
      m, "IInt8EntropyCalibrator", py::module_local(), "Int8 Entropy Calibrator class")
      .def(py::init_alias<>()) // Always initialize trampoline class.
      .def("get_batch_size", &IInt8EntropyCalibrator::getBatchSize, "Get batch size")
      .def("get_algorithm", &IInt8EntropyCalibrator::getAlgorithm, "Get algorithm");

  py::class_<IInt8EntropyCalibrator2, IInt8Calibrator, pyCalibratorTrampoline<IInt8EntropyCalibrator2>>(
      m, "IInt8EntropyCalibrator2", py::module_local(), "Int8 Entropy Calibrator2 class")
      .def(py::init_alias<>()) // Always initialize trampoline class.
      .def("get_batch_size", &IInt8EntropyCalibrator2::getBatchSize, "Get batch size")
      .def("get_algorithm", &IInt8EntropyCalibrator2::getAlgorithm, "Get algorithm");

  py::class_<IInt8MinMaxCalibrator, IInt8Calibrator, pyCalibratorTrampoline<IInt8MinMaxCalibrator>>(
      m, "IInt8MinMaxCalibrator", py::module_local(), "Int8 MinMax Calibrator class")
      .def(py::init_alias<>()) // Always initialize trampoline class.
      .def("get_batch_size", &IInt8MinMaxCalibrator::getBatchSize, "Get batch size")
      .def("get_algorithm", &IInt8MinMaxCalibrator::getAlgorithm, "Get algorithm");

  py::class_<CompileSpec>(m, "CompileSpec")
      .def(py::init<>())
      .def_readwrite("input_ranges", &CompileSpec::input_ranges)
      .def_readwrite("op_precision", &CompileSpec::op_precision)
      .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
      .def_readwrite("refit", &CompileSpec::refit)
      .def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
      .def_readwrite("debug", &CompileSpec::debug)
33 changes: 33 additions & 0 deletions py/trtorch/csrc/util.h
@@ -0,0 +1,33 @@
#pragma once
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <functional>
#include <iostream>
#include <string>
#include "core/util/logging/TRTorchLogger.h"

namespace trtorch {
namespace pyapi {
namespace util {

namespace py = pybind11;

// Method for calling the python function and returning the value (returned from python) used in cpp trampoline
// classes. Prints an error if no such method is overridden in python.
// T* must NOT be a trampoline class!
template <typename T>
py::function getOverload(const T* self, const std::string& overloadName) {
  py::function overload = py::get_override(self, overloadName.c_str());
  if (!overload) {
    std::string msg{"Method: " + overloadName + " was not overridden. Please provide an implementation for this method."};
    core::util::logging::get_logger().log(core::util::logging::LogLevel::kERROR, msg);
  }
  return overload;
}

} // namespace util
} // namespace pyapi
} // namespace trtorch
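From the Python side, the effect of this guard is that a missing override is reported by name rather than failing opaquely. A sketch with a deliberately incomplete, hypothetical subclass:

import trtorch

class IncompleteCalibrator(trtorch._C.IInt8EntropyCalibrator2):

    def get_batch_size(self):
        return 1
    # get_batch, read_calibration_cache and write_calibration_cache are not
    # defined, so when TensorRT invokes one of them, getOverload() logs
    # "Method: get_batch was not overridden. Please provide an implementation
    # for this method." at kERROR level, naming the missing method.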
106 changes: 106 additions & 0 deletions py/trtorch/ptq.py
@@ -0,0 +1,106 @@
from typing import List, Dict, Any
import torch
import os

import trtorch._C
from trtorch._compile_spec import _parse_compile_spec
from trtorch._version import __version__
from types import FunctionType
from enum import Enum


class CalibrationAlgo(Enum):
    ENTROPY_CALIBRATION = trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION
    ENTROPY_CALIBRATION_2 = trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION_2
    LEGACY_CALIBRATION = trtorch._C.CalibrationAlgo.LEGACY_CALIBRATION
    MINMAX_CALIBRATION = trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION


# The functions below are module-level on purpose: they take `self` because
# they are injected as methods into the calibrator classes built dynamically
# by DataLoaderCalibrator/CacheCalibrator via their attribute mappings.
def get_cache_mode_batch(self):
    return None


def get_batch_size(self):
    return 1


def get_batch(self, names):
    if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
        return None
    batch = next(self.dataset_iterator)  # iterator.next() does not exist in Python 3
    self.current_batch_idx += self.batch_size
    print("Calibrating batch: ", self.current_batch_idx)
    # Treat the first element as input and others as targets.
    if isinstance(batch, list):
        batch = batch[0].to(torch.device('cuda:0'))
    return [batch.data_ptr()]


def read_calibration_cache(self):
    if self.use_cache:
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()


def write_calibration_cache(self, cache):
    with open(self.cache_file, "wb") as f:
        f.write(cache)

class DataLoaderCalibrator(object):

    def __init__(self, dataloader, cache_file, use_cache, algo_type):
        self.algo_type = algo_type
        if use_cache:
            if os.path.isfile(cache_file):
                print("Using existing cache file for calibration ", cache_file)
            else:
                raise ValueError("use_cache flag is True but cache file not found.")

        # Define attributes and member functions for the calibrator class
        self.attribute_mapping = {
            'data_loader': dataloader,
            'current_batch_idx': 0,
            'batch_size': dataloader.batch_size,
            'dataset_iterator': iter(dataloader),
            'cache_file': cache_file,
            'use_cache': use_cache,
            'get_batch_size': get_batch_size,
            'get_batch': get_cache_mode_batch if use_cache else get_batch,
            'read_calibration_cache': read_calibration_cache,
            'write_calibration_cache': write_calibration_cache
        }

    def __call__(self):
        # Use the type() metaclass to construct the calibrator class for the chosen algorithm
        if self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator,), self.attribute_mapping)()
        elif self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator2,), self.attribute_mapping)()
        elif self.algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8LegacyCalibrator,), self.attribute_mapping)()
        elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
        else:
            raise ValueError(
                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, LEGACY_CALIBRATION or MINMAX_CALIBRATION")
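A usage sketch for the class above; the dataset, transform, and file names are illustrative stand-ins:

import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from trtorch import ptq

testing_dataset = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=1)

# Note the trailing (): __call__ is what builds the actual calibrator object
# for the chosen algorithm.
calibrator = ptq.DataLoaderCalibrator(testing_dataloader,
                                      cache_file='./calibration.cache',
                                      use_cache=False,
                                      algo_type=ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)()

The result is then passed as the "calibrator" entry of a compile spec, as in the sketch under _parse_compile_spec above.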

class CacheCalibrator(object):

    def __init__(self, cache_file, algo_type):
        self.algo_type = algo_type
        if os.path.isfile(cache_file):
            print("Using cache file for calibration ", cache_file)
        else:
            raise ValueError("Calibration cache file not found at " + cache_file)

        # Define attributes and member functions for the calibrator class
        self.attribute_mapping = {
            'use_cache': True,
            'cache_file': cache_file,
            'get_batch_size': get_batch_size,
            'get_batch': get_cache_mode_batch,
            'read_calibration_cache': read_calibration_cache,
            'write_calibration_cache': write_calibration_cache
        }

    def __call__(self):
        # Use the type() metaclass to construct the calibrator class for the chosen algorithm
        if self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
            return type('CacheCalibrator', (trtorch._C.IInt8EntropyCalibrator,), self.attribute_mapping)()
        elif self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
            return type('CacheCalibrator', (trtorch._C.IInt8EntropyCalibrator2,), self.attribute_mapping)()
        elif self.algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
            return type('CacheCalibrator', (trtorch._C.IInt8LegacyCalibrator,), self.attribute_mapping)()
        elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
            return type('CacheCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
        else:
            raise ValueError(
                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, LEGACY_CALIBRATION or MINMAX_CALIBRATION")
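And a matching sketch for compiling straight from a previously written cache, with no dataloader in the loop (file path illustrative):

from trtorch import ptq

# Reuses a cache produced by an earlier DataLoaderCalibrator run; since
# get_batch is get_cache_mode_batch, no calibration data is ever fetched.
calibrator = ptq.CacheCalibrator(cache_file='./calibration.cache',
                                 algo_type=ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)()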