
feat(py/trtorch/ptq): Implement INT8 Python API for PTQ #390


Merged: 20 commits merged into master from int8_py on Mar 17, 2021

Conversation

@peri044 (Collaborator) commented Mar 7, 2021

Description

Implement INT8 PTQ Python API support. Discussion thread: https://github.com/NVIDIA/TRTorch/discussions/346
Three main goals are addressed:

  • Implement data loader calibrators and cache variants in Python
  • Re-export TensorRT bindings through TRTorch, including support for using TRT calibrators directly in TRTorch without symbol conflicts
  • Use Python metaclasses instead of writing separate data loader calibrators for each calibration algorithm (a condensed sketch follows below)

Fixes #173, #57
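
For context, a condensed sketch of the metaclass approach, pieced together from the ptq.py diffs shown later in this thread; this is an illustration, not the exact merged implementation:

    import trtorch

    def get_batch_size(self):
        # Becomes a method of the generated calibrator class.
        return self.batch_size

    def make_int8_calibrator(dataloader, cache_file, use_cache, algo_type):
        # Attributes and member functions the generated class will carry.
        attribute_mapping = {
            'data_loader': dataloader,
            'batch_size': dataloader.batch_size,
            'cache_file': cache_file,
            'use_cache': use_cache,
            'get_batch_size': get_batch_size,
        }
        # type() builds a class on the fly that inherits from the re-exported
        # TensorRT calibrator binding, so one code path covers every algorithm.
        if algo_type == trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
        raise ValueError("Unsupported calibration algorithm type")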

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes

Signed-off-by: Dheeraj Peri <[email protected]>
Signed-off-by: Dheeraj Peri <[email protected]>
Signed-off-by: Dheeraj Peri <[email protected]>
Signed-off-by: Dheeraj Peri <[email protected]>
@github-actions bot added labels component: api [Python] (Issues re: Python API) and component: api [C++] (Issues re: C++ API) on Mar 7, 2021
@peri044 peri044 requested a review from narendasan March 7, 2021 10:57
@peri044 peri044 changed the title Implement INT8 Python API for PTQ feat(py/trtorch/ptq): Implement INT8 Python API for PTQ Mar 7, 2021
@github-actions bot left a comment

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/cpp/api/include/trtorch/trtorch.h b/tmp/changes.txt
index 7e107ef..dc32c66 100644
--- a/workspace/cpp/api/include/trtorch/trtorch.h
+++ b/tmp/changes.txt
@@ -34,7 +34,7 @@ class IInt8Calibrator;
}
#endif // DOXYGEN_SHOULD_SKIP_THIS

-#include "trtorch/macros.h" // 
+#include "trtorch/macros.h" //
namespace trtorch {
/**
 * Settings data structure for TRTorch compilation
ERROR: Some files do not conform to style guidelines

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /workspace/py/trtorch/ptq.py	(original)
+++ /workspace/py/trtorch/ptq.py	(reformatted)
@@ -8,11 +8,14 @@
from types import FunctionType
import tensorrt as trt

+
def get_cache_mode_batch(self):
    return None

+
def get_batch_size(self):
    return self.batch_size
+

def get_batch(self, names):
    if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
@@ -25,28 +28,33 @@
        batch = batch[0].to(torch.device('cuda:0'))
    return [batch.data_ptr()]

+
def read_calibration_cache(self):
    if self.use_cache:
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

+
def write_calibration_cache(self, cache):
    with open(self.cache_file, "wb") as f:
        f.write(cache)

+
def make_int8_calibrator(dataloader, cache_file, use_cache, algo_type):
    # Define attributes and member functions for the calibrator class
-    attribute_mapping={'data_loader' : dataloader,
-                       'current_batch_idx' : 0,
-                       'batch_size' : dataloader.batch_size,
-                       'dataset_iterator' : iter(dataloader),
-                       'cache_file' : cache_file,
-                       'use_cache' : use_cache,
-                       'get_batch_size' : get_batch_size,
-                       'get_batch': get_cache_mode_batch if use_cache else get_batch,
-                       'read_calibration_cache' : read_calibration_cache,
-                       'write_calibration_cache' : write_calibration_cache}
+    attribute_mapping = {
+        'data_loader': dataloader,
+        'current_batch_idx': 0,
+        'batch_size': dataloader.batch_size,
+        'dataset_iterator': iter(dataloader),
+        'cache_file': cache_file,
+        'use_cache': use_cache,
+        'get_batch_size': get_batch_size,
+        'get_batch': get_cache_mode_batch if use_cache else get_batch,
+        'read_calibration_cache': read_calibration_cache,
+        'write_calibration_cache': write_calibration_cache
+    }

    # Using type metaclass to construct calibrator class based on algorithm type
    if algo_type == trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION:
@@ -58,4 +66,6 @@
    elif algo_type == trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION:
        return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
    else:
-        return ValueError("Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION");
+        return ValueError(
+            "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+        )
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

}
};

class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<IInt8LegacyCalibrator> {
Collaborator

Do we need to support this at all? I think that for the DataLoader calibrator we should only support current features. Is there a deprecation plan for the Legacy Calibrator?

Collaborator Author

I checked with TRT and this will stay at least through the next release; I haven't seen any deprecation plan yet. I think there's no harm in supporting it, and it still worked when I ran calibration.

@narendasan (Collaborator)

Also, it seems like we still need the CacheCalibrator implementation, in addition to tests for all three paths.

@github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /workspace/py/trtorch/ptq.py	(original)
+++ /workspace/py/trtorch/ptq.py	(reformatted)
@@ -12,11 +12,14 @@
LEGACY_CALIBRATION = trtorch._C.CalibrationAlgo.LEGACY_CALIBRATION
MINMAX_CALIBRATION = trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION

+
def get_cache_mode_batch(self):
    return None

+
def get_batch_size(self):
    return self.batch_size
+

def get_batch(self, names):
    if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
@@ -29,28 +32,33 @@
        batch = batch[0].to(torch.device('cuda:0'))
    return [batch.data_ptr()]

+
def read_calibration_cache(self):
    if self.use_cache:
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

+
def write_calibration_cache(self, cache):
    with open(self.cache_file, "wb") as f:
        f.write(cache)

+
def make_int8_calibrator(dataloader, cache_file, use_cache, algo_type):
    # Define attributes and member functions for the calibrator class
-    attribute_mapping={'data_loader' : dataloader,
-                       'current_batch_idx' : 0,
-                       'batch_size' : dataloader.batch_size,
-                       'dataset_iterator' : iter(dataloader),
-                       'cache_file' : cache_file,
-                       'use_cache' : use_cache,
-                       'get_batch_size' : get_batch_size,
-                       'get_batch': get_cache_mode_batch if use_cache else get_batch,
-                       'read_calibration_cache' : read_calibration_cache,
-                       'write_calibration_cache' : write_calibration_cache}
+    attribute_mapping = {
+        'data_loader': dataloader,
+        'current_batch_idx': 0,
+        'batch_size': dataloader.batch_size,
+        'dataset_iterator': iter(dataloader),
+        'cache_file': cache_file,
+        'use_cache': use_cache,
+        'get_batch_size': get_batch_size,
+        'get_batch': get_cache_mode_batch if use_cache else get_batch,
+        'read_calibration_cache': read_calibration_cache,
+        'write_calibration_cache': write_calibration_cache
+    }

    # Using type metaclass to construct calibrator class based on algorithm type
    if algo_type == ENTROPY_CALIBRATION:
@@ -62,4 +70,6 @@
    elif algo_type == MINMAX_CALIBRATION:
        return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
    else:
-        return ValueError("Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION");
+        return ValueError(
+            "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+        )
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

@peri044 (Collaborator Author)

peri044 commented Mar 9, 2021

@narendasan The cache calibrator works through the same code path: when use_cache=True, we bind get_cache_mode_batch, which returns None, instead of get_batch, which returns the next batch. That was the only difference I found between the non-cache and cache calibrators.
I also cleaned up the code, addressed some of the comments, and posted questions on the points that were still unclear.
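
A condensed illustration of that switch, following the ptq.py diff earlier in this thread:

    # In cache mode the calibrator never produces data: TensorRT reads the scales
    # from the cache file instead, so get_batch simply signals "no more batches".
    def get_cache_mode_batch(self):
        return None

    # In the attribute mapping, the method bound as get_batch is chosen per mode:
    #     'get_batch': get_cache_mode_batch if use_cache else get_batch,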

@narendasan (Collaborator)

Right, but say I have a cache and no dataloader: how do I use DataLoaderCalibrator?
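
The CacheCalibrator class added in a later commit of this PR covers that cache-only case; a usage sketch based on the later ptq.py diff (the cache path here is hypothetical):

    import trtorch

    # Reads scales from an existing calibration cache; no dataloader required.
    calibrator = trtorch.ptq.CacheCalibrator(
        "./calibration.cache",  # hypothetical path for illustration
        algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)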

@@ -6,7 +6,11 @@
from trtorch._compile_spec import _parse_compile_spec
from trtorch._version import __version__
from types import FunctionType
import tensorrt as trt

ENTROPY_CALIBRATION = trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION
Collaborator

Is it possible to group these in an enum or something called CalibrationAlgo? I think we do something similar for logging.

Collaborator Author

Done. Is the new class okay, or is there a better way to expose them?
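
A sketch of the grouping referred to here, following the later ptq.py diff (the members are assumed to mirror the re-exported trtorch._C.CalibrationAlgo bindings):

    class CalibrationAlgo:
        ENTROPY_CALIBRATION = trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION
        ENTROPY_CALIBRATION_2 = trtorch._C.CalibrationAlgo.ENTROPY_CALIBRATION_2
        LEGACY_CALIBRATION = trtorch._C.CalibrationAlgo.LEGACY_CALIBRATION
        MINMAX_CALIBRATION = trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION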

@narendasan (Collaborator)

Also, it seems like we are missing tests.

@github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /workspace/py/trtorch/ptq.py	(original)
+++ /workspace/py/trtorch/ptq.py	(reformatted)
@@ -15,11 +15,14 @@
    LEGACY_CALIBRATION = trtorch._C.CalibrationAlgo.LEGACY_CALIBRATION
    MINMAX_CALIBRATION = trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION

+
def get_cache_mode_batch(self):
    return None

+
def get_batch_size(self):
    return 1
+

def get_batch(self, names):
    if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
@@ -32,17 +35,21 @@
        batch = batch[0].to(torch.device('cuda:0'))
    return [batch.data_ptr()]

+
def read_calibration_cache(self):
    if self.use_cache:
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

+
def write_calibration_cache(self, cache):
    with open(self.cache_file, "wb") as f:
        f.write(cache)

+
class DataLoaderCalibrator(object):
+
    def __init__(self, dataloader, cache_file, use_cache, algo_type):
        self.algo_type = algo_type
        if use_cache:
@@ -52,16 +59,18 @@
                raise ValueError("use_cache flag is True but cache file not found.")

        # Define attributes and member functions for the calibrator class
-        self.attribute_mapping={'data_loader' : dataloader,
-                               'current_batch_idx' : 0,
-                               'batch_size' : dataloader.batch_size,
-                               'dataset_iterator' : iter(dataloader),
-                               'cache_file' : cache_file,
-                               'use_cache' : use_cache,
-                               'get_batch_size' : get_batch_size,
-                               'get_batch': get_cache_mode_batch if use_cache else get_batch,
-                               'read_calibration_cache' : read_calibration_cache,
-                               'write_calibration_cache' : write_calibration_cache}
+        self.attribute_mapping = {
+            'data_loader': dataloader,
+            'current_batch_idx': 0,
+            'batch_size': dataloader.batch_size,
+            'dataset_iterator': iter(dataloader),
+            'cache_file': cache_file,
+            'use_cache': use_cache,
+            'get_batch_size': get_batch_size,
+            'get_batch': get_cache_mode_batch if use_cache else get_batch,
+            'read_calibration_cache': read_calibration_cache,
+            'write_calibration_cache': write_calibration_cache
+        }

    def __call__(self):
        # Using type metaclass to construct calibrator class based on algorithm type
@@ -74,9 +83,13 @@
        elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
        else:
-            return ValueError("Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION");
+            return ValueError(
+                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+            )
+

class CacheCalibrator(object):
+
    def __init__(self, cache_file, algo_type):
        self.algo_type = algo_type
        if os.path.isfile(cache_file):
@@ -85,12 +98,14 @@
            raise ValueError("Calibration cache file not found at ", cache_file)

        # Define attributes and member functions for the calibrator class
-        self.attribute_mapping={'use_cache' : True,
-                                'cache_file' : cache_file,
-                                'get_batch_size' : get_batch_size,
-                                'get_batch': get_cache_mode_batch,
-                                'read_calibration_cache' : read_calibration_cache,
-                                'write_calibration_cache' : write_calibration_cache}
+        self.attribute_mapping = {
+            'use_cache': True,
+            'cache_file': cache_file,
+            'get_batch_size': get_batch_size,
+            'get_batch': get_cache_mode_batch,
+            'read_calibration_cache': read_calibration_cache,
+            'write_calibration_cache': write_calibration_cache
+        }

    def __call__(self):
        # Using type metaclass to construct calibrator class based on algorithm type
@@ -103,4 +118,6 @@
        elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
        else:
-            return ValueError("Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION");
+            return ValueError(
+                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+            )
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

@narendasan (Collaborator) left a comment

Cool, I think we are close. The last things we need for this PR are the unresolved comments, the test cases, and an addition to the PTQ article in docsrc.

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /workspace/py/trtorch/ptq.py	(original)
+++ /workspace/py/trtorch/ptq.py	(reformatted)
@@ -16,11 +16,14 @@
    LEGACY_CALIBRATION = trtorch._C.CalibrationAlgo.LEGACY_CALIBRATION
    MINMAX_CALIBRATION = trtorch._C.CalibrationAlgo.MINMAX_CALIBRATION

+
def get_cache_mode_batch(self):
    return None

+
def get_batch_size(self):
    return 1
+

def get_batch(self, names):
    if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
@@ -36,18 +39,22 @@
        batch = batch[0].to(self.device)
    return [batch.data_ptr()]

+
def read_calibration_cache(self):
    if self.cache_file and self.use_cache:
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

+
def write_calibration_cache(self, cache):
    if self.cache_file:
        with open(self.cache_file, "wb") as f:
            f.write(cache)

+
class DataLoaderCalibrator(object):
+
    def __init__(self, dataloader, **kwargs):
        self.algo_type = kwargs.get("algo_type", trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)
        self.cache_file = kwargs.get("cache_file", None)
@@ -55,7 +62,8 @@
        self.device = kwargs.get("device", torch.device("cuda:0"))

        if not isinstance(dataloader, torch.utils.data.DataLoader):
-            log(Level.Error, "Dataloader : {} is not a valid instance of torch.utils.data.DataLoader".format(dataloader))
+            log(Level.Error,
+                "Dataloader : {} is not a valid instance of torch.utils.data.DataLoader".format(dataloader))

        if not self.cache_file:
            if self.use_cache:
@@ -67,17 +75,19 @@
                log(Level.Error, "Input cache file is None but use_cache is set to True in INT8 mode.")

        # Define attributes and member functions for the calibrator class
-        self.attribute_mapping={'data_loader' : dataloader,
-                               'current_batch_idx' : 0,
-                               'batch_size' : dataloader.batch_size,
-                               'dataset_iterator' : iter(dataloader),
-                               'cache_file' : self.cache_file,
-                               'device' : self.device,
-                               'use_cache' : self.use_cache,
-                               'get_batch_size' : get_batch_size,
-                               'get_batch': get_cache_mode_batch if self.use_cache else get_batch,
-                               'read_calibration_cache' : read_calibration_cache,
-                               'write_calibration_cache' : write_calibration_cache}
+        self.attribute_mapping = {
+            'data_loader': dataloader,
+            'current_batch_idx': 0,
+            'batch_size': dataloader.batch_size,
+            'dataset_iterator': iter(dataloader),
+            'cache_file': self.cache_file,
+            'device': self.device,
+            'use_cache': self.use_cache,
+            'get_batch_size': get_batch_size,
+            'get_batch': get_cache_mode_batch if self.use_cache else get_batch,
+            'read_calibration_cache': read_calibration_cache,
+            'write_calibration_cache': write_calibration_cache
+        }

    def __call__(self):
        # Using type metaclass to construct calibrator class based on algorithm type
@@ -90,9 +100,14 @@
        elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
        else:
-            log(Level.Error, "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION");
+            log(
+                Level.Error,
+                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+            )
+

class CacheCalibrator(object):
+
    def __init__(self, cache_file, **kwargs):
        self.cache_file = cache_file
        self.algo_type = kwargs.get("algo_type", trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)
@@ -103,12 +118,14 @@
            log(Level.Error, "Invalid calibration cache file.")

        # Define attributes and member functions for the calibrator class
-        self.attribute_mapping={'use_cache' : True,
-                                'cache_file' : self.cache_file,
-                                'get_batch_size' : get_batch_size,
-                                'get_batch': get_cache_mode_batch,
-                                'read_calibration_cache' : read_calibration_cache,
-                                'write_calibration_cache' : write_calibration_cache}
+        self.attribute_mapping = {
+            'use_cache': True,
+            'cache_file': self.cache_file,
+            'get_batch_size': get_batch_size,
+            'get_batch': get_cache_mode_batch,
+            'read_calibration_cache': read_calibration_cache,
+            'write_calibration_cache': write_calibration_cache
+        }

    def __call__(self):
        # Using type metaclass to construct calibrator class based on algorithm type
@@ -121,4 +138,7 @@
        elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
            return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
        else:
-            log(Level.Error, "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION");
+            log(
+                Level.Error,
+                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+            )
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

@github-actions bot left a comment

Code conforms to C++ style guidelines

@peri044 (Collaborator Author)

peri044 commented Mar 12, 2021

Working on the test cases and documentation. I will add them soon.

@github-actions bot added labels component: tests (Issues re: Tests) and documentation (Improvements or additions to documentation) on Mar 14, 2021
@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/setup.py
--- /workspace/docsrc/conf.py	(original)
+++ /workspace/docsrc/conf.py	(reformatted)
@@ -12,6 +12,7 @@
#
import os
import sys
+
sys.path.append(os.path.join(os.path.dirname(__name__), '../py'))

import sphinx_material
Reformatting /workspace/docsrc/conf.py
ERROR: Some files do not conform to style guidelines

@github-actions bot left a comment

Code conforms to C++ style guidelines

Signed-off-by: Dheeraj Peri <[email protected]>
@github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /workspace/docsrc/conf.py	(original)
+++ /workspace/docsrc/conf.py	(reformatted)
@@ -12,6 +12,7 @@
#
import os
import sys
+
sys.path.append(os.path.join(os.path.dirname(__name__), '../py'))

import sphinx_material
Reformatting /workspace/docsrc/conf.py
ERROR: Some files do not conform to style guidelines

Signed-off-by: Dheeraj Peri <[email protected]>
@github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions bot left a comment

Code conforms to Python style guidelines

with torch.no_grad():
idx = 0
for data, labels in testing_dataloader:
data, labels = data.cuda(), labels.cuda(async=True)
Collaborator

This uses a deprecated API; we should probably update it.
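
For reference, the async keyword argument to Tensor.cuda() was replaced by non_blocking in PyTorch, so the non-deprecated form of the line above would be:

    data, labels = data.cuda(), labels.cuda(non_blocking=True)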

'write_calibration_cache': write_calibration_cache
}

self.calibrator = type('DataLoaderCalibrator', (trt.IInt8EntropyCalibrator,), self.attribute_mapping)()
Collaborator

I get that this is probably easier, but we should at least try to implement the calibrator the way the TRT docs do, even if it just wraps the DataLoaderCalibrator functions: actually create a class that inherits properly.
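
The test added later in this PR (test_ptq_trt_calibrator.py) follows that pattern; a trimmed sketch of that style, with the constructor arguments and defaults assumed:

    import tensorrt as trt
    import torch

    class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2):

        def __init__(self, dataloader, **kwargs):
            # Explicit subclassing, as in the TensorRT docs, rather than a type()-built class.
            trt.IInt8EntropyCalibrator2.__init__(self)
            self.batch_size = dataloader.batch_size
            self.dataset_iterator = iter(dataloader)
            self.device = kwargs.get("device", torch.device("cuda:0"))
            self.cache_file = kwargs.get("cache_file", None)

        def get_batch_size(self):
            return self.batch_size

        def get_batch(self, names):
            batch = next(self.dataset_iterator, None)
            if batch is None:
                return None  # tells TensorRT that calibration data is exhausted
            # Keep a reference so the GPU buffer stays alive while TensorRT reads it.
            self.current_batch = batch[0].to(self.device)
            return [self.current_batch.data_ptr()]

        def read_calibration_cache(self):
            # Returning None forces recalibration; a real implementation would read self.cache_file.
            return None

        def write_calibration_cache(self, cache):
            if self.cache_file:
                with open(self.cache_file, "wb") as f:
                    f.write(cache)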


.. code-block:: python

self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
Collaborator

Let's remove the selfs here.


// using namespace nvinfer1;
Collaborator

delete this line

Signed-off-by: Dheeraj Peri <[email protected]>
@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/docsrc/conf.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_api_dla.py
--- /workspace/tests/py/test_ptq_trt_calibrator.py	(original)
+++ /workspace/tests/py/test_ptq_trt_calibrator.py	(reformatted)
@@ -10,7 +10,9 @@
import torchvision.transforms as transforms
from model_test_case import ModelTestCase

+
class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2):
+
    def __init__(self, dataloader, **kwargs):
        trt.IInt8EntropyCalibrator2.__init__(self)

@@ -40,7 +42,6 @@
            batch = batch[0].to(self.device)
        return [batch.data_ptr()]

-
    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if self.use_cache:
@@ -51,6 +52,7 @@
        if self.cache_file:
            with open(self.cache_file, "wb") as f:
                f.write(cache)
+

class TestAccuracy(ModelTestCase):

Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
ERROR: Some files do not conform to style guidelines

@github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions bot left a comment

Code conforms to Python style guidelines

@narendasan narendasan merged commit b4b12a1 into master Mar 17, 2021
@narendasan narendasan deleted the int8_py branch March 17, 2021 20:47
Labels: component: api [C++] (Issues re: C++ API), component: api [Python] (Issues re: Python API), component: tests (Issues re: Tests), documentation (Improvements or additions to documentation)

Successfully merging this pull request may close these issues: PTQ in python API