diff --git a/.gitignore b/.gitignore
index b75f6cb62e..6fe7dc60df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ bazel-genfiles
 bazel-out
 bazel-testlogs
 bazel-TRTorch
+bazel-trtorch-testing
 third_party/pytorch
 *.jit
 *.jit.pt
@@ -37,4 +38,6 @@ bdist
 py/trtorch/_version.py
 py/wheelhouse
 py/.eggs
-notebooks/.ipynb_checkpoints/
\ No newline at end of file
+notebooks/.ipynb_checkpoints/
+*.cache
+tests/py/data
\ No newline at end of file
diff --git a/docs/_sources/tutorials/ptq.rst.txt b/docs/_sources/tutorials/ptq.rst.txt
index 28d60acec3..f8e526d426 100644
--- a/docs/_sources/tutorials/ptq.rst.txt
+++ b/docs/_sources/tutorials/ptq.rst.txt
@@ -14,14 +14,17 @@ the TensorRT calibrator. With TRTorch we look to leverage existing infrastructur
 calibrators easier. LibTorch provides a ``DataLoader`` and ``Dataset`` API which streamlines preprocessing and batching input data.
-This section of the PyTorch documentation has more information https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data.
+These APIs are exposed via both the C++ and Python interfaces, which makes them easier for the end user to adopt.
+For the C++ interface, we use ``torch::Dataset`` and ``torch::data::make_data_loader`` objects to construct and perform pre-processing on datasets.
+The equivalent functionality in the Python interface uses ``torch.utils.data.Dataset`` and ``torch.utils.data.DataLoader``.
+These sections of the PyTorch documentation have more information: https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data and https://pytorch.org/tutorials/recipes/recipes/loading_data_recipe.html.
 TRTorch uses Dataloaders as the base of a generic calibrator implementation. So you will be able to reuse or quickly
 implement a ``torch::Dataset`` for your target domain, place it in a DataLoader and create an INT8 Calibrator
 which you can provide to TRTorch to run INT8 Calibration during compilation of your module.
 
-.. _writing_ptq:
+.. _writing_ptq_cpp:
 
-How to create your own PTQ application
-----------------------------------------
+How to create your own PTQ application in C++
+---------------------------------------------
 
 Here is an example interface of a ``torch::Dataset`` class for CIFAR10:
@@ -132,11 +135,73 @@ Then all thats required to setup the module for INT8 calibration is to set the f
 
     auto trt_mod = trtorch::CompileGraph(mod, compile_spec);
 
 If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
-
 From here not much changes in terms of how execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
 in FP32 precision when it's passed into ``trt_mod.forward``.
 There exists an example application in the TRTorch demo that takes you from training a VGG16 network on CIFAR10 to deploying in INT8 with TRTorch here: https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq
 
+.. _writing_ptq_python:
+
+How to create your own PTQ application in Python
+--------------------------------------------------
+
+The TRTorch Python API provides an easy and convenient way to use PyTorch dataloaders with TensorRT calibrators. The ``DataLoaderCalibrator`` class can be used to create
+a TensorRT calibrator by providing the desired configuration. The following code demonstrates how to use it:
+
+.. code-block:: python
+
+    testing_dataset = torchvision.datasets.CIFAR10(root='./data',
+                                                   train=False,
+                                                   download=True,
+                                                   transform=transforms.Compose([
+                                                       transforms.ToTensor(),
+                                                       transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                                                            (0.2023, 0.1994, 0.2010))
+                                                   ]))
+
+    testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
+                                                     batch_size=1,
+                                                     shuffle=False,
+                                                     num_workers=1)
+    calibrator = trtorch.ptq.DataLoaderCalibrator(testing_dataloader,
+                                                  cache_file='./calibration.cache',
+                                                  use_cache=False,
+                                                  algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
+                                                  device=torch.device('cuda:0'))
+
+    compile_spec = {
+        "input_shapes": [[1, 3, 32, 32]],
+        "op_precision": torch.int8,
+        "calibrator": calibrator,
+        "disable_tf32": False,
+        "device": {
+            "device_type": trtorch.DeviceType.GPU,
+            "gpu_id": 0,
+            "dla_core": 0,
+            "allow_gpu_fallback": False
+        }
+    }
+    trt_mod = trtorch.compile(model, compile_spec)
+
+In cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
+to use ``CacheCalibrator`` in INT8 mode.
+
+.. code-block:: python
+
+    calibrator = trtorch.ptq.CacheCalibrator("./calibration.cache")
+
+    compile_settings = {
+        "input_shapes": [[1, 3, 32, 32]],
+        "op_precision": torch.int8,
+        "calibrator": calibrator,
+        "max_batch_size": 32,
+    }
+
+    trt_mod = trtorch.compile(model, compile_settings)
+
+If you already have an existing calibrator class (implemented directly using the TensorRT API), you can directly set the calibrator field to an instance of your class, which can be very convenient.
+For a demo of how PTQ can be performed on a VGG network using the TRTorch API, you can refer to https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_dataloader_calibrator.py
+and https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_trt_calibrator.py
+
 Citations
 ^^^^^^^^^^^
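Taken together, the two classes documented above support a common two-phase workflow: calibrate from a dataloader once (the cache file is written as a side effect), then rebuild engines from the cache alone. A minimal sketch of that pattern, assuming ``model`` and ``testing_dataloader`` are defined as in the examples above; the cache path is illustrative:

.. code-block:: python

    import os
    import torch
    import trtorch

    CACHE_FILE = './calibration.cache'  # illustrative path

    if not os.path.exists(CACHE_FILE):
        # First build: calibrate from data; the cache is written once calibration runs.
        calibrator = trtorch.ptq.DataLoaderCalibrator(
            testing_dataloader,
            cache_file=CACHE_FILE,
            use_cache=False,
            algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
            device=torch.device('cuda:0'))
    else:
        # Subsequent builds: reuse the cache, no dataloader required.
        calibrator = trtorch.ptq.CacheCalibrator(CACHE_FILE)

    trt_mod = trtorch.compile(model, {
        "input_shapes": [[1, 3, 32, 32]],
        "op_precision": torch.int8,
        "calibrator": calibrator,
    })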
diff --git a/docsrc/conf.py b/docsrc/conf.py
index 251e5c6012..505fd6d942 100644
--- a/docsrc/conf.py
+++ b/docsrc/conf.py
@@ -12,6 +12,7 @@ #
 import os
 import sys
+
 sys.path.append(os.path.join(os.path.dirname(__name__), '../py'))
 import sphinx_material
diff --git a/py/BUILD b/py/BUILD
index be5b2d7047..c6ca7efcf4 100644
--- a/py/BUILD
+++ b/py/BUILD
@@ -9,6 +9,7 @@ py_library(
         "trtorch/__init__.py",
         "trtorch/_version.py",
         "trtorch/_compiler.py",
+        "trtorch/ptq.py",
         "trtorch/_compile_spec.py",
         "trtorch/_types.py",
         "trtorch/logging.py"
@@ -21,4 +22,4 @@ py_library(
     deps = [
         requirement("torch")
     ]
-)
\ No newline at end of file
+)
diff --git a/py/trtorch/__init__.py b/py/trtorch/__init__.py
index b61a3d2854..49e13e71d4 100644
--- a/py/trtorch/__init__.py
+++ b/py/trtorch/__init__.py
@@ -10,6 +10,7 @@
 from trtorch._version import __version__
 from trtorch._compiler import *
 from trtorch._compile_spec import TensorRTCompileSpec
+from trtorch import ptq
 from trtorch._types import *
 from trtorch import logging
diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index 814be63c13..4abab01790 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -135,6 +135,9 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     if "op_precision" in compile_spec:
         info.op_precision = _parse_op_precision(compile_spec["op_precision"])
 
+    if "calibrator" in compile_spec:
+        info.ptq_calibrator = compile_spec["calibrator"]
+
     if "disable_tf32" in compile_spec:
         assert isinstance(compile_spec["disable_tf32"], bool)
         info.disable_tf32 = compile_spec["disable_tf32"]
@@ -254,5 +257,6 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
     backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
     backend_spec.set_workspace_size(parsed_spec.workspace_size)
     backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
+    backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
 
     return backend_spec
diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp
index 048e69dbe1..36ff34931d 100644
--- a/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -29,6 +29,7 @@ void RegisterTRTCompileSpec() {
       .def(torch::init<>())
       .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
       .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
+      .def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
       .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index 54b47d9111..e0b36c4463 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   }
   auto info = core::CompileSpec(internal_input_ranges);
   info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
+  info.convert_info.engine_settings.calibrator = ptq_calibrator;
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;
   info.convert_info.engine_settings.debug = debug;
diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h
index a60390d1e2..5371b93f75 100644
--- a/py/trtorch/csrc/tensorrt_classes.h
+++ b/py/trtorch/csrc/tensorrt_classes.h
@@ -94,10 +94,18 @@ struct CompileSpec : torch::CustomClassHolder {
     input_ranges.push_back(*ir);
   }
 
+  int64_t getPTQCalibratorHandle() {
+    return (int64_t)ptq_calibrator;
+  }
+
   void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
     device = *d;
   }
 
+  void setPTQCalibratorViaHandle(int64_t handle) {
+    ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
+  }
+
   ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
   ADD_FIELD_GET_SET(disable_tf32, bool);
   ADD_FIELD_GET_SET(refit, bool);
@@ -109,8 +117,10 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(workspace_size, int64_t);
   ADD_FIELD_GET_SET(max_batch_size, int64_t);
   ADD_FIELD_GET_SET(device, Device);
+  ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
 
   std::vector<InputRange> input_ranges;
+  nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
   DataType op_precision = DataType::kFloat;
   bool disable_tf32 = false;
   bool refit = false;
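The handle methods added above exist because the TorchScript custom class cannot hold a raw ``nvinfer1::IInt8Calibrator*`` directly: the pointer is reinterpreted as an ``int64_t`` on the pybind11 side and cast back inside the backend spec. A sketch of the round trip using the internal bindings registered in this patch; this mirrors what ``TensorRTCompileSpec`` does, assumes the TorchScript class is registered as ``torch.classes.tensorrt.CompileSpec`` (as the truncated signature above suggests), and is not a user-facing API:

.. code-block:: python

    import torch
    import trtorch

    # 'calibrator' is assumed to be any calibrator instance built as in the docs above.
    parsed_spec = trtorch._C.CompileSpec()
    parsed_spec.ptq_calibrator = calibrator        # pybind11 extracts the raw IInt8Calibrator*
    handle = parsed_spec._get_calibrator_handle()  # pointer reinterpreted as int64_t
    backend_spec = torch.classes.tensorrt.CompileSpec()
    backend_spec._set_ptq_calibrator(handle)       # cast back to a pointer on the C++ side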
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index 418423db41..e1a5b14eb4 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -9,12 +9,96 @@
 #include "torch/custom_class.h"
 #include "torch/script.h"
 #include "torch/torch.h"
+#include "util.h"
 
 namespace py = pybind11;
 
 namespace trtorch {
 namespace pyapi {
 
+// CRTP trampoline that forwards TensorRT's calibrator virtuals to Python overrides.
+template <typename Derived>
+class pyCalibratorTrampoline : public Derived {
+ public:
+  using Derived::Derived; // Inherit constructors
+
+  int getBatchSize() const noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
+  }
+
+  bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
+    py::gil_scoped_acquire gil{};
+
+    py::function pyGetBatch = trtorch::pyapi::util::getOverload(static_cast<const Derived*>(this), "get_batch");
+    std::vector<const char*> namesVec(names, names + nbBindings);
+    py::object result = pyGetBatch(namesVec);
+    // Copy over into the other data structure.
+    if (!result.is_none() && result.cast<std::vector<size_t>>().size() != 0) {
+      std::memcpy(bindings, result.cast<std::vector<size_t>>().data(), nbBindings * sizeof(void*));
+      return true;
+    }
+    return false;
+  }
+
+  const void* readCalibrationCache(std::size_t& length) noexcept override {
+    py::gil_scoped_acquire gil{};
+
+    py::function pyReadCalibrationCache =
+        trtorch::pyapi::util::getOverload(static_cast<const Derived*>(this), "read_calibration_cache");
+    py::buffer cache = pyReadCalibrationCache();
+    if (!cache.is_none()) {
+      py::buffer_info info = cache.request();
+      length = info.size * info.itemsize;
+      return info.ptr;
+    }
+    return nullptr;
+  }
+
+  void writeCalibrationCache(const void* ptr, std::size_t length) noexcept override {
+    py::gil_scoped_acquire gil{};
+
+    py::function pyWriteCalibrationCache =
+        trtorch::pyapi::util::getOverload(static_cast<const Derived*>(this), "write_calibration_cache");
+
+    py::memoryview cache{py::memoryview::from_buffer(static_cast<const uint8_t*>(ptr), {length}, {sizeof(uint8_t)})};
+    pyWriteCalibrationCache(cache);
+  }
+};
+
+class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrator> {
+ public:
+  using Derived = pyCalibratorTrampoline<nvinfer1::IInt8Calibrator>;
+  using Derived::Derived;
+
+  nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(
+        nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
+  }
+};
+
+class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator> {
+ public:
+  using Derived = pyCalibratorTrampoline<nvinfer1::IInt8LegacyCalibrator>;
+  using Derived::Derived;
+
+  double getQuantile() const noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
+  }
+
+  double getRegressionCutoff() const noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
+  }
+
+  const void* readHistogramCache(std::size_t& length) noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(
+        const void*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
+  }
+
+  void writeHistogramCache(const void* ptr, std::size_t length) noexcept override {
+    PYBIND11_OVERLOAD_PURE_NAME(
+        void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
+  }
+};
+
 void set_device(const int device_id) {
   core::set_device(device_id);
 }
@@ -102,10 +186,57 @@ PYBIND11_MODULE(_C, m) {
       .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
       .value("default", EngineCapability::kDEFAULT, "Use default behavior");
 
+  py::enum_<nvinfer1::CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm")
+      .value("LEGACY_CALIBRATION", nvinfer1::CalibrationAlgoType::kLEGACY_CALIBRATION)
+      .value("ENTROPY_CALIBRATION", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION)
+      .value("ENTROPY_CALIBRATION_2", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION_2)
+      .value("MINMAX_CALIBRATION", nvinfer1::CalibrationAlgoType::kMINMAX_CALIBRATION);
+
+  py::class_<nvinfer1::IInt8Calibrator, pyIInt8Calibrator>(
+      m, "IInt8Calibrator", py::module_local(), "Int8 Calibrator base class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8Calibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8Calibrator::getAlgorithm, "Get algorithm");
+
+  py::class_<nvinfer1::IInt8LegacyCalibrator, nvinfer1::IInt8Calibrator, pyIInt8LegacyCalibrator>(
+      m, "IInt8LegacyCalibrator", py::module_local(), "Int8 Legacy Calibrator class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8LegacyCalibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8LegacyCalibrator::getAlgorithm, "Get algorithm");
+
+  py::class_<
+      nvinfer1::IInt8EntropyCalibrator,
+      nvinfer1::IInt8Calibrator,
+      pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator>>(
+      m, "IInt8EntropyCalibrator", py::module_local(), "Int8 Entropy Calibrator class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator::getAlgorithm, "Get algorithm");
+
+  py::class_<
+      nvinfer1::IInt8EntropyCalibrator2,
+      nvinfer1::IInt8Calibrator,
+      pyCalibratorTrampoline<nvinfer1::IInt8EntropyCalibrator2>>(
+      m, "IInt8EntropyCalibrator2", py::module_local(), "Int8 Entropy Calibrator2 class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8EntropyCalibrator2::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8EntropyCalibrator2::getAlgorithm, "Get algorithm");
+
+  py::class_<
+      nvinfer1::IInt8MinMaxCalibrator,
+      nvinfer1::IInt8Calibrator,
+      pyCalibratorTrampoline<nvinfer1::IInt8MinMaxCalibrator>>(
+      m, "IInt8MinMaxCalibrator", py::module_local(), "Int8 MinMax Calibrator class")
+      .def(py::init_alias<>()) // Always initialize trampoline class.
+      .def("get_batch_size", &nvinfer1::IInt8MinMaxCalibrator::getBatchSize, "Get batch size")
+      .def("get_algorithm", &nvinfer1::IInt8MinMaxCalibrator::getAlgorithm, "Get algorithm");
+
   py::class_<CompileSpec>(m, "CompileSpec")
       .def(py::init<>())
+      .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
       .def_readwrite("input_ranges", &CompileSpec::input_ranges)
       .def_readwrite("op_precision", &CompileSpec::op_precision)
+      .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
       .def_readwrite("refit", &CompileSpec::refit)
       .def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
       .def_readwrite("debug", &CompileSpec::debug)
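With these bindings a calibrator can also be written directly in Python by subclassing one of the exposed base classes; the trampoline dispatches TensorRT's virtual calls to the snake_case methods. A minimal sketch of such a subclass (hypothetical user code; the cache path is illustrative, and behavior follows the trampoline above, where an empty or ``None`` result from ``get_batch`` ends calibration):

.. code-block:: python

    import torch
    import trtorch

    class MyCalibrator(trtorch._C.IInt8EntropyCalibrator2):  # hypothetical user class

        def __init__(self, dataloader):
            super().__init__()
            self.iterator = iter(dataloader)
            self.batch_size = dataloader.batch_size

        def get_batch_size(self):
            return self.batch_size

        def get_batch(self, names):
            # TensorRT expects one device pointer per binding name.
            try:
                batch = next(self.iterator)
            except StopIteration:
                return None  # signals that calibration data is exhausted
            data = batch[0] if isinstance(batch, (list, tuple)) else batch
            # Keep a reference on self so the CUDA buffer stays alive while TensorRT reads it.
            self._current = data.to('cuda:0')
            return [self._current.data_ptr()]

        def read_calibration_cache(self):
            return None  # no cache; always calibrate from data

        def write_calibration_cache(self, cache):
            with open('./calibration.cache', 'wb') as f:
                f.write(cache)

``trtorch.ptq.DataLoaderCalibrator`` (added below) constructs an equivalent class dynamically; a hand-written subclass like this is mainly useful when custom batching logic is needed.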
diff --git a/py/trtorch/csrc/util.h b/py/trtorch/csrc/util.h
new file mode 100644
index 0000000000..22515f5a68
--- /dev/null
+++ b/py/trtorch/csrc/util.h
@@ -0,0 +1,31 @@
+#pragma once
+#include <functional>
+#include <iostream>
+#include <string>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace pyapi {
+namespace util {
+
+namespace py = pybind11;
+
+// Method for calling the python function and returning the value (returned from python) used in cpp trampoline
+// classes. Prints an error if no such method is overridden in python.
+// T* must NOT be a trampoline class!
+template <typename T>
+py::function getOverload(const T* self, const std::string& overloadName) {
+  py::function overload = py::get_override(self, overloadName.c_str());
+  if (!overload) {
+    std::string msg{"Method: " + overloadName +
+                    " was not overridden. Please provide an implementation for this method."};
+    LOG_ERROR(msg);
+  }
+  return overload;
+}
+
+} // namespace util
+} // namespace pyapi
+} // namespace trtorch
diff --git a/py/trtorch/ptq.py b/py/trtorch/ptq.py
new file mode 100644
index 0000000000..a82b4bb6d5
--- /dev/null
+++ b/py/trtorch/ptq.py
@@ -0,0 +1,161 @@
+from typing import List, Dict, Any
+import torch
+import os
+
+import trtorch._C
+from trtorch._compile_spec import _parse_compile_spec
+from trtorch._version import __version__
+from trtorch.logging import *
+from types import FunctionType
+from enum import Enum
+
+
+class CalibrationAlgo(Enum):
+    ENTROPY_CALIBRATION = trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION
+    ENTROPY_CALIBRATION_2 = trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION_2
+    LEGACY_CALIBRATION = trtorch._C.CalibrationAlgo.LEGACY_CALIBRATION
+    MINMAX_CALIBRATION = trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION
+
+
+def get_cache_mode_batch(self):
+    # In cache mode there is no data to feed; returning None ends calibration immediately.
+    return None
+
+
+def get_batch_size(self):
+    return 1
+
+
+def get_batch(self, names):
+    if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
+        return None
+
+    batch = next(self.dataset_iterator)
+    self.current_batch_idx += self.batch_size
+    # Treat the first element as input and others as targets.
+    if isinstance(batch, list):
+        batch = batch[0].to(self.device)
+    return [batch.data_ptr()]
+
+
+def read_calibration_cache(self):
+    if self.cache_file and self.use_cache:
+        if os.path.exists(self.cache_file):
+            with open(self.cache_file, "rb") as f:
+                return f.read()
+
+
+def write_calibration_cache(self, cache):
+    if self.cache_file:
+        with open(self.cache_file, "wb") as f:
+            f.write(cache)
+
+
+class DataLoaderCalibrator(object):
+    """
+    Constructs a calibrator class in TensorRT and uses a pytorch dataloader to load/preprocess
+    data which is passed during calibration.
+    Args:
+        dataloader: an instance of pytorch dataloader which iterates through a given dataset.
+        algo_type: choice of calibration algorithm.
+        cache_file: path to cache file.
+        use_cache: flag which enables usage of pre-existing cache.
+        device: device on which calibration data is copied to.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __new__(cls, *args, **kwargs):
+        dataloader = args[0]
+        algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2)
+        cache_file = kwargs.get("cache_file", None)
+        use_cache = kwargs.get("use_cache", False)
+        device = kwargs.get("device", torch.device("cuda:0"))
+
+        if not isinstance(dataloader, torch.utils.data.DataLoader):
+            log(Level.Error,
+                "Dataloader : {} is not a valid instance of torch.utils.data.DataLoader".format(dataloader))
+
+        if cache_file:
+            if use_cache:
+                log(Level.Debug, "Using existing cache_file {} for calibration".format(cache_file))
+            else:
+                log(Level.Debug, "Overwriting existing calibration cache file.")
+        else:
+            if use_cache:
+                log(Level.Error, "Input cache file is None but use_cache is set to True in INT8 mode.")
+
+        # Define attributes and member functions for the calibrator class
+        attribute_mapping = {
+            'data_loader': dataloader,
+            'current_batch_idx': 0,
+            'batch_size': dataloader.batch_size,
+            'dataset_iterator': iter(dataloader),
+            'cache_file': cache_file,
+            'device': device,
+            'use_cache': use_cache,
+            'get_batch_size': get_batch_size,
+            'get_batch': get_cache_mode_batch if use_cache else get_batch,
+            'read_calibration_cache': read_calibration_cache,
+            'write_calibration_cache': write_calibration_cache
+        }
+
+        # Using the type metaclass to construct the calibrator class based on algorithm type
+        if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
+            return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator,), attribute_mapping)()
+        elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
+            return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator2,), attribute_mapping)()
+        elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
+            return type('DataLoaderCalibrator', (trtorch._C.IInt8LegacyCalibrator,), attribute_mapping)()
+        elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
+            return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
+        else:
+            log(
+                Level.Error,
+                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+            )
+
+
+class CacheCalibrator(object):
+    """
+    Constructs a calibrator class in TensorRT which directly uses a pre-existing cache file for calibration.
+    Args:
+        cache_file: path to cache file.
+        algo_type: choice of calibration algorithm.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __new__(cls, *args, **kwargs):
+        cache_file = args[0]
+        algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2)
+
+        if os.path.isfile(cache_file):
+            log(Level.Debug, "Using existing cache_file {} for calibration".format(cache_file))
+        else:
+            log(Level.Error, "Invalid calibration cache file.")
+
+        # Define attributes and member functions for the calibrator class
+        attribute_mapping = {
+            'use_cache': True,
+            'cache_file': cache_file,
+            'get_batch_size': get_batch_size,
+            'get_batch': get_cache_mode_batch,
+            'read_calibration_cache': read_calibration_cache,
+            'write_calibration_cache': write_calibration_cache
+        }
+        # Using the type metaclass to construct the calibrator class based on algorithm type
+        if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
+            return type('CacheCalibrator', (trtorch._C.IInt8EntropyCalibrator,), attribute_mapping)()
+        elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
+            return type('CacheCalibrator', (trtorch._C.IInt8EntropyCalibrator2,), attribute_mapping)()
+        elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
+            return type('CacheCalibrator', (trtorch._C.IInt8LegacyCalibrator,), attribute_mapping)()
+        elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
+            return type('CacheCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
+        else:
+            log(
+                Level.Error,
+                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+            )
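The ``type(name, bases, namespace)`` calls above are the dynamic form of a class statement: they manufacture a subclass of the chosen TensorRT binding whose state and methods come from ``attribute_mapping``, then return an instance of it. A self-contained toy demonstrating the pattern (``object`` stands in for the TensorRT base class):

.. code-block:: python

    def get_batch_size(self):
        return self.batch_size

    # Plain functions placed in the namespace dict become methods of the new class,
    # and non-callable values become class attributes.
    attrs = {'batch_size': 4, 'get_batch_size': get_batch_size}
    Calib = type('DataLoaderCalibrator', (object,), attrs)

    c = Calib()
    assert c.get_batch_size() == 4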
diff --git a/tests/BUILD b/tests/BUILD
index 3965e882d8..761f16781b 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -34,6 +34,14 @@ test_suite(
     name = "python_api_tests",
     tests = [
         "//tests/py:test_api",
-        "//tests/py:test_to_backend_api"
+        "//tests/py:test_to_backend_api",
     ]
 )
+
+test_suite(
+    name = "python_required_and_optional_tests",
+    tests = [
+        ":python_api_tests",
+        "//tests/py:py_calibrator_tests"
+    ]
+)
\ No newline at end of file
diff --git a/tests/py/BUILD b/tests/py/BUILD
index 2f20daaf67..510b3f681e 100644
--- a/tests/py/BUILD
+++ b/tests/py/BUILD
@@ -25,6 +25,30 @@ py_test(
     ]
 )
 
+py_test(
+    name = "test_ptq_dataloader_calibrator",
+    srcs = [
+        "test_ptq_dataloader_calibrator.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
+# This test is not included in the main test suite by default. This test checks
+# if trtorch can use pre-existing trt calibrators already implemented by users.
+py_test(
+    name = "test_ptq_trt_calibrator",
+    srcs = [
+        "test_ptq_trt_calibrator.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
 # Following multi_gpu test is only targeted for multi-gpu configurations. It is not included in the test suite by default.
 py_test(
     name = "test_multi_gpu",
@@ -49,3 +73,23 @@ py_test(
         requirement("torchvision")
     ]
 )
+
+py_test(
+    name = "test_ptq_to_backend",
+    srcs = [
+        "test_ptq_to_backend.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
+test_suite(
+    name = "py_calibrator_tests",
+    tests = [
+        ":test_ptq_to_backend",
+        ":test_ptq_trt_calibrator",
+        ":test_ptq_dataloader_calibrator"
+    ],
+)
\ No newline at end of file
diff --git a/tests/py/test_ptq_dataloader_calibrator.py b/tests/py/test_ptq_dataloader_calibrator.py
new file mode 100644
index 0000000000..a22aeef3b9
--- /dev/null
+++ b/tests/py/test_ptq_dataloader_calibrator.py
@@ -0,0 +1,96 @@
+import unittest
+import trtorch
+from trtorch.logging import *
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torchvision
+import torchvision.transforms as transforms
+from model_test_case import ModelTestCase
+
+
+class TestAccuracy(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 32, 32)).to("cuda")
+        self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
+                                                            train=False,
+                                                            download=True,
+                                                            transform=transforms.Compose([
+                                                                transforms.ToTensor(),
+                                                                transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                                                                     (0.2023, 0.1994, 0.2010))
+                                                            ]))
+
+        self.testing_dataloader = torch.utils.data.DataLoader(self.testing_dataset,
+                                                              batch_size=1,
+                                                              shuffle=False,
+                                                              num_workers=1)
+        self.calibrator = trtorch.ptq.DataLoaderCalibrator(self.testing_dataloader,
+                                                           cache_file='./calibration.cache',
+                                                           use_cache=False,
+                                                           algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
+                                                           device=torch.device('cuda:0'))
+
+    def compute_accuracy(self, testing_dataloader, model):
+        total = 0
+        correct = 0
+        loss = 0.0
+        class_probs = []
+        class_preds = []
+        device = torch.device('cuda:0')
+        with torch.no_grad():
+            idx = 0
+            for data, labels in testing_dataloader:
+                data, labels = data.to(device), labels.to(device)
+                out = model(data)
+                preds = torch.max(out, 1)[1]
+                class_probs.append([F.softmax(i, dim=0) for i in out])
+                class_preds.append(preds)
+                total += labels.size(0)
+                correct += (preds == labels).sum().item()
+                idx += 1
+
+        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
+        test_preds = torch.cat(class_preds)
+        return correct / total
+
+    def test_compile_script(self):
+
+        fp32_test_acc = self.compute_accuracy(self.testing_dataloader, self.model)
+        log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
+
+        compile_spec = {
+            "input_shapes": [[1, 3, 32, 32]],
+            "op_precision": torch.int8,
+            "calibrator": self.calibrator,
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+            }
+        }
+
+        trt_mod = trtorch.compile(self.model, compile_spec)
+        int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
+        log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
+        acc_diff = fp32_test_acc - int8_test_acc
+        self.assertTrue(abs(acc_diff) * 100 < 3)  # accuracies are fractions; compare in percentage points
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    # You need a pre-trained VGG cifar10 model to run this test. Please follow instructions at
+    # https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq/training/vgg16 to export this model.
+    suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16.jit.pt')))
+
+    return suite
+
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
diff --git a/tests/py/test_ptq_to_backend.py b/tests/py/test_ptq_to_backend.py
new file mode 100644
index 0000000000..ae665dda71
--- /dev/null
+++ b/tests/py/test_ptq_to_backend.py
@@ -0,0 +1,99 @@
+import unittest
+import trtorch
+from trtorch.logging import *
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torchvision
+import torchvision.transforms as transforms
+from model_test_case import ModelTestCase
+
+
+class TestAccuracy(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 32, 32)).to("cuda")
+        self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
+                                                            train=False,
+                                                            download=True,
+                                                            transform=transforms.Compose([
+                                                                transforms.ToTensor(),
+                                                                transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                                                                     (0.2023, 0.1994, 0.2010))
+                                                            ]))
+
+        self.testing_dataloader = torch.utils.data.DataLoader(self.testing_dataset,
+                                                              batch_size=1,
+                                                              shuffle=False,
+                                                              num_workers=1)
+        self.calibrator = trtorch.ptq.DataLoaderCalibrator(self.testing_dataloader,
+                                                           cache_file='./calibration.cache',
+                                                           use_cache=False,
+                                                           algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
+                                                           device=torch.device('cuda:0'))
+
+        self.spec = {
+            "forward":
+                trtorch.TensorRTCompileSpec({
+                    "input_shapes": [[1, 3, 32, 32]],
+                    "op_precision": torch.int8,
+                    "calibrator": self.calibrator,
+                    "device": {
+                        "device_type": trtorch.DeviceType.GPU,
+                        "gpu_id": 0,
+                        "dla_core": 0,
+                        "allow_gpu_fallback": False,
+                    }
+                })
+        }
+
+    def compute_accuracy(self, testing_dataloader, model):
+        total = 0
+        correct = 0
+        loss = 0.0
+        class_probs = []
+        class_preds = []
+
+        with torch.no_grad():
+            idx = 0
+            for data, labels in testing_dataloader:
+                data, labels = data.cuda(), labels.cuda(non_blocking=True)
+                out = model(data)
+                preds = torch.max(out, 1)[1]
+                class_probs.append([F.softmax(i, dim=0) for i in out])
+                class_preds.append(preds)
+                total += labels.size(0)
+                correct += (preds == labels).sum().item()
+                idx += 1
+
+        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
+        test_preds = torch.cat(class_preds)
+        return correct / total
+
+    def test_compile_script(self):
+
+        fp32_test_acc = self.compute_accuracy(self.testing_dataloader, self.model)
+        log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
+
+        trt_mod = torch._C._jit_to_backend("tensorrt", self.model, self.spec)
+        int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
+        log(Level.Info, "[TRT INT8 Backend] Test Acc: {:.2f}%".format(100 * int8_test_acc))
+        acc_diff = fp32_test_acc - int8_test_acc
+        self.assertTrue(abs(acc_diff) * 100 < 3)  # accuracies are fractions; compare in percentage points
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    # You need a pre-trained VGG cifar10 model to run this test. Please follow instructions at
+    # https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq/training/vgg16 to export this model.
+    suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16.jit.pt')))
+
+    return suite
+
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
diff --git a/tests/py/test_ptq_trt_calibrator.py b/tests/py/test_ptq_trt_calibrator.py
new file mode 100644
index 0000000000..737ecee4be
--- /dev/null
+++ b/tests/py/test_ptq_trt_calibrator.py
@@ -0,0 +1,138 @@
+import unittest
+import os
+import trtorch
+from trtorch.logging import *
+import torch
+import tensorrt as trt
+import torch.nn as nn
+from torch.nn import functional as F
+import torchvision
+import torchvision.transforms as transforms
+from model_test_case import ModelTestCase
+
+
+class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2):
+
+    def __init__(self, dataloader, **kwargs):
+        trt.IInt8EntropyCalibrator2.__init__(self)
+
+        self.cache_file = kwargs.get("cache_file", None)
+        self.use_cache = kwargs.get("use_cache", False)
+        self.device = kwargs.get("device", torch.device("cuda:0"))
+
+        self.dataloader = dataloader
+        self.dataset_iterator = iter(dataloader)
+        self.batch_size = dataloader.batch_size
+        self.current_batch_idx = 0
+
+    def get_batch_size(self):
+        return 1
+
+    # TensorRT passes along the names of the engine bindings to the get_batch function.
+    # You don't necessarily have to use them, but they can be useful to understand the order of
+    # the inputs. The bindings list is expected to have the same ordering as 'names'.
+    def get_batch(self, names):
+        if self.current_batch_idx + self.batch_size > self.dataloader.dataset.data.shape[0]:
+            return None
+
+        batch = next(self.dataset_iterator)
+        self.current_batch_idx += self.batch_size
+        # Treat the first element as input and others as targets.
+        if isinstance(batch, list):
+            batch = batch[0].to(self.device)
+        return [batch.data_ptr()]
+
+    def read_calibration_cache(self):
+        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
+        if self.use_cache:
+            with open(self.cache_file, "rb") as f:
+                return f.read()
+
+    def write_calibration_cache(self, cache):
+        if self.cache_file:
+            with open(self.cache_file, "wb") as f:
+                f.write(cache)
+
+
+class TestAccuracy(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 32, 32)).to("cuda")
+        self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
+                                                            train=False,
+                                                            download=True,
+                                                            transform=transforms.Compose([
+                                                                transforms.ToTensor(),
+                                                                transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                                                                     (0.2023, 0.1994, 0.2010))
+                                                            ]))
+
+        self.testing_dataloader = torch.utils.data.DataLoader(self.testing_dataset,
+                                                              batch_size=1,
+                                                              shuffle=False,
+                                                              num_workers=1)
+        # Test cases can assume using GPU id: 0
+        self.calibrator = TRTEntropyCalibrator(self.testing_dataloader)
+
+    def compute_accuracy(self, testing_dataloader, model):
+        total = 0
+        correct = 0
+        loss = 0.0
+        class_probs = []
+        class_preds = []
+        device = torch.device('cuda:0')
+        with torch.no_grad():
+            idx = 0
+            for data, labels in testing_dataloader:
+                data, labels = data.to(device), labels.to(device)
+                out = model(data)
+                preds = torch.max(out, 1)[1]
+                class_probs.append([F.softmax(i, dim=0) for i in out])
+                class_preds.append(preds)
+                total += labels.size(0)
+                correct += (preds == labels).sum().item()
+                idx += 1
+
+        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
+        test_preds = torch.cat(class_preds)
+        return correct / total
+
+    def test_compile_script(self):
+
+        fp32_test_acc = self.compute_accuracy(self.testing_dataloader, self.model)
+        log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
+
+        compile_spec = {
+            "input_shapes": [[1, 3, 32, 32]],
+            "op_precision": torch.int8,
+            "calibrator": self.calibrator,
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+            }
+        }
+
+        trt_mod = trtorch.compile(self.model, compile_spec)
+        int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
+        log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
+        acc_diff = fp32_test_acc - int8_test_acc
+        self.assertTrue(abs(acc_diff) * 100 < 3)  # accuracies are fractions; compare in percentage points
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    # You need a pre-trained VGG cifar10 model to run this test. Please follow instructions at
+    # https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq/training/vgg16 to export this model.
+    suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16.jit.pt')))
+
+    return suite
+
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))