-
Notifications
You must be signed in to change notification settings - Fork 363
feat(py/trtorch/ptq): Implement INT8 Python API for PTQ #390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 13 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
9376527
Add int8 calibration API in python
peri044 f497fe2
py binding changes
peri044 3f4466c
Introduce INT8 API and various support
peri044 9ce54a9
Fix merge conflicts
peri044 dff26d3
Add utils and remove redundant headers
peri044 eea3114
Remove redundant changes
peri044 0dd7083
Add PTQ module
peri044 86916bd
Address review comments part-1
peri044 9f41818
Change interface of calibrator
peri044 e86e932
Address review comments - part 2
peri044 71adb44
Add test suite for PTQ python API
peri044 322a415
Change class instantiation
peri044 fe5654f
Fix linter
peri044 92aa6c5
Address review comments
peri044 6c3e0ad
feat(//py): Allowing people using the PyTorch backend to use TRTorch/TRT
narendasan b4484b4
Merge branch 'int8_py' of https://github.com/NVIDIA/TRTorch into int8_py
narendasan 076bab0
refactor(//tests): Couple of edits and organization for new int8 py
narendasan f309262
refactor(//tests): Apply linting
narendasan 088d586
chore(//tests): Adding instructions on where to get model for PTQ python
narendasan a4e40ca
Merge pull request #398 from NVIDIA/to_backend_int8
narendasan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,12 +9,97 @@ | |
#include "torch/custom_class.h" | ||
#include "torch/script.h" | ||
#include "torch/torch.h" | ||
#include "util.h" | ||
|
||
// using namespace nvinfer1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete this line |
||
namespace py = pybind11; | ||
|
||
namespace trtorch { | ||
namespace pyapi { | ||
|
||
template <typename Derived> | ||
class pyCalibratorTrampoline : public Derived { | ||
public: | ||
using Derived::Derived; // Inherit constructors | ||
|
||
int getBatchSize() const noexcept override { | ||
PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize); | ||
} | ||
|
||
bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { | ||
py::gil_scoped_acquire gil{}; | ||
|
||
py::function pyGetBatch = trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "get_batch"); | ||
std::vector<const char*> namesVec(names, names + nbBindings); | ||
py::object result = pyGetBatch(namesVec); | ||
// Copy over into the other data structure. | ||
if (!result.is_none() && result.cast<std::vector<size_t>>().size() != 0) { | ||
std::memcpy(bindings, result.cast<std::vector<size_t>>().data(), nbBindings * sizeof(void*)); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
const void* readCalibrationCache(std::size_t& length) noexcept override { | ||
py::gil_scoped_acquire gil{}; | ||
|
||
py::function pyReadCalibrationCache = | ||
trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "read_calibration_cache"); | ||
py::buffer cache = pyReadCalibrationCache(); | ||
if (!cache.is_none()) { | ||
py::buffer_info info = cache.request(); | ||
length = info.size * info.itemsize; | ||
return info.ptr; | ||
} | ||
return nullptr; | ||
} | ||
|
||
void writeCalibrationCache(const void* ptr, std::size_t length) noexcept override { | ||
py::gil_scoped_acquire gil{}; | ||
|
||
py::function pyWriteCalibrationCache = | ||
trtorch::pyapi::util::getOverload(static_cast<Derived*>(this), "write_calibration_cache"); | ||
|
||
py::memoryview cache{py::memoryview::from_buffer(static_cast<const uint8_t*>(ptr), {length}, {sizeof(uint8_t)})}; | ||
pyWriteCalibrationCache(cache); | ||
} | ||
}; | ||
|
||
// Trampoline for the IInt8Calibrator base interface so python classes can
// derive from it directly; adds the "get_algorithm" dispatch on top of the
// shared pyCalibratorTrampoline machinery.
class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrator> {
 public:
  using Derived = pyCalibratorTrampoline<nvinfer1::IInt8Calibrator>;
  using Derived::Derived;

  // Dispatches to the python "get_algorithm" override (pure virtual python-side).
  nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(
        nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
  }
};
|
||
// Trampoline for the legacy calibrator interface, which (unlike the other
// calibrators) additionally requires quantile / regression-cutoff parameters
// and its own histogram cache. Each method dispatches to the python override
// named in the corresponding macro (all pure virtual python-side).
class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator> {
 public:
  using Derived = pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator>;
  using Derived::Derived;

  // Quantile used by the legacy calibration algorithm; supplied by python.
  double getQuantile() const noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
  }

  // Regression cutoff used by the legacy calibration algorithm; supplied by python.
  double getRegressionCutoff() const noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
  }

  // NOTE(review): unlike readCalibrationCache in the trampoline, the pointer
  // lifetime here is delegated entirely to the python override — confirm the
  // python side keeps the histogram cache bytes alive.
  const void* readHistogramCache(std::size_t& length) noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(
        const void*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
  }

  // Hands the histogram cache to python for serialization.
  void writeHistogramCache(const void* ptr, std::size_t length) noexcept override {
    PYBIND11_OVERLOAD_PURE_NAME(
        void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
  }
};
|
||
// Selects the active CUDA device for subsequent work; thin python-facing
// wrapper over the core runtime's device selection.
void set_device(int device_id) {
  core::set_device(device_id);
}
|
@@ -102,10 +187,56 @@ PYBIND11_MODULE(_C, m) { | |
.value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only") | ||
.value("default", EngineCapability::kDEFAULT, "Use default behavior"); | ||
|
||
py::enum_<nvinfer1::CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm") | ||
.value("LEGACY_CALIBRATION", nvinfer1::CalibrationAlgoType::kLEGACY_CALIBRATION) | ||
.value("ENTROPY_CALIBRATION", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION) | ||
.value("ENTROPY_CALIBRATION_2", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION_2) | ||
.value("MINMAX_CALIBRATION", nvinfer1::CalibrationAlgoType::kMINMAX_CALIBRATION); | ||
|
||
py::class_<nvinfer1::IInt8Calibrator, pyIInt8Calibrator>( | ||
m, "IInt8Calibrator", py::module_local(), "Int8 Calibrator base class") | ||
.def(py::init_alias<>()) // Always initialize trampoline class. | ||
.def("get_batch_size", &nvinfer1::IInt8Calibrator::getBatchSize, "Get batch size") | ||
.def("get_algorithm", &nvinfer1::IInt8Calibrator::getAlgorithm, "Get algorithm"); | ||
|
||
py::class_<nvinfer1::IInt8LegacyCalibrator, nvinfer1::IInt8Calibrator, pyIInt8LegacyCalibrator>( | ||
m, "IInt8LegacyCalibrator", py::module_local(), "Int8 Legacy Calibrator class") | ||
.def(py::init_alias<>()) // Always initialize trampoline class. | ||
.def("get_batch_size", &nvinfer1::IInt8LegacyCalibrator::getBatchSize, "Get batch size") | ||
.def("get_algorithm", &nvinfer1::IInt8LegacyCalibrator::getAlgorithm, "Get algorithm"); | ||
|
||
py::class_< | ||
nvinfer1::IInt8EntropyCalibrator, | ||
nvinfer1::IInt8Calibrator, | ||
pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator>>( | ||
m, "IInt8EntropyCalibrator", py::module_local(), "Int8 Entropy Calibrator class") | ||
.def(py::init_alias<>()) // Always initialize trampoline class. | ||
.def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator::getBatchSize, "Get batch size") | ||
.def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator::getAlgorithm, "Get algorithm"); | ||
|
||
py::class_< | ||
nvinfer1::IInt8EntropyCalibrator2, | ||
nvinfer1::IInt8Calibrator, | ||
pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator2>>( | ||
m, "IInt8EntropyCalibrator2", py::module_local(), "Int8 Entropy Calibrator2 class") | ||
.def(py::init_alias<>()) // Always initialize trampoline class. | ||
.def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator2::getBatchSize, "Get batch size") | ||
.def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator2::getAlgorithm, "Get algorithm"); | ||
|
||
py::class_< | ||
nvinfer1::IInt8MinMaxCalibrator, | ||
nvinfer1::IInt8Calibrator, | ||
pyCalibratorTrampoline<nvinfer1::IInt8MinMaxCalibrator>>( | ||
m, "IInt8MinMaxCalibrator", py::module_local(), "Int8 MinMax Calibrator class") | ||
.def(py::init_alias<>()) // Always initialize trampoline class. | ||
.def("get_batch_size", &nvinfer1::IInt8MinMaxCalibrator::getBatchSize, "Get batch size") | ||
.def("get_algorithm", &nvinfer1::IInt8MinMaxCalibrator::getAlgorithm, "Get algorithm"); | ||
|
||
py::class_<CompileSpec>(m, "CompileSpec") | ||
.def(py::init<>()) | ||
.def_readwrite("input_ranges", &CompileSpec::input_ranges) | ||
.def_readwrite("op_precision", &CompileSpec::op_precision) | ||
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator) | ||
.def_readwrite("refit", &CompileSpec::refit) | ||
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32) | ||
.def_readwrite("debug", &CompileSpec::debug) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#pragma once | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/pybind11.h> | ||
#include <functional> | ||
#include <iostream> | ||
#include <string> | ||
#include "core/util/prelude.h" | ||
|
||
namespace trtorch { | ||
namespace pyapi { | ||
namespace util { | ||
|
||
namespace py = pybind11; | ||
|
||
// Method for calling the python function and returning the value (returned from python) used in cpp trampoline | ||
// classes. Prints an error if no such method is overriden in python. | ||
// T* must NOT be a trampoline class! | ||
template <typename T> | ||
py::function getOverload(const T* self, const std::string& overloadName) { | ||
py::function overload = py::get_override(self, overloadName.c_str()); | ||
if (!overload) { | ||
std::string msg{"Method: " + overloadName + | ||
" was not overriden. Please provide an implementation for this method."}; | ||
LOG_ERROR(msg); | ||
} | ||
return overload; | ||
} | ||
|
||
} // namespace util | ||
} // namespace pyapi | ||
} // namespace trtorch |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets remove the
self
s here