Adds support for PTQ through the PyTorch to_backend api. #398
Conversation
INT8 calibrators Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
tests Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/docsrc/conf.py
--- /workspace/tests/py/test_ptq_trt_calibrator.py (original)
+++ /workspace/tests/py/test_ptq_trt_calibrator.py (reformatted)
@@ -10,7 +10,9 @@
import torchvision.transforms as transforms
from model_test_case import ModelTestCase
+
class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2):
+
def __init__(self, dataloader, **kwargs):
trt.IInt8EntropyCalibrator2.__init__(self)
@@ -40,7 +42,6 @@
batch = batch[0].to(self.device)
return [batch.data_ptr()]
-
def read_calibration_cache(self):
# If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
if self.use_cache:
@@ -51,6 +52,7 @@
if self.cache_file:
with open(self.cache_file, "wb") as f:
f.write(cache)
+
class TestAccuracy(ModelTestCase):
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
ERROR: Some files do not conform to style guidelines
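For orientation, the TRTEntropyCalibrator fragments visible in the formatter diff above come from a class along these lines. This is a hedged sketch filled out from the standard tensorrt.IInt8EntropyCalibrator2 interface, not the exact contents of test_ptq_trt_calibrator.py; anything not shown in the diff (the kwarg names, the batch-size handling) is an assumption.

import tensorrt as trt
import torch


class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2):

    def __init__(self, dataloader, **kwargs):
        trt.IInt8EntropyCalibrator2.__init__(self)

        # use_cache / cache_file appear in the diff above; the other kwarg names are assumed.
        self.cache_file = kwargs.get("cache_file", None)
        self.use_cache = kwargs.get("use_cache", False)
        self.device = kwargs.get("device", torch.device("cuda:0"))

        self.dataloader = dataloader
        self.batch_iter = iter(self.dataloader)

    def get_batch_size(self):
        # Calibration batch size reported to TensorRT (assumed to mirror the dataloader).
        return self.dataloader.batch_size

    def get_batch(self, names):
        try:
            batch = next(self.batch_iter)
        except StopIteration:
            # Out of data: returning None tells TensorRT that calibration is finished.
            return None
        # Keep a reference so the device buffer stays alive while TensorRT reads from it.
        self.current_batch = batch[0].to(self.device)
        return [self.current_batch.data_ptr()]

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if self.use_cache:
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        if self.cache_file:
            with open(self.cache_file, "wb") as f:
                f.write(cache)

TensorRT drives an object like this during engine building: it calls get_batch() repeatedly to collect activation statistics and can persist the resulting scales through write_calibration_cache() so later builds can skip calibration.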
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
LGTM
with torch.no_grad():
    idx = 0
    for data, labels in testing_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
Maybe data.to(device) to avoid warnings?
This shouldn't throw warnings since it's using the new API. The .cuda API wasn't deprecated; it was just the async flag that was renamed to non_blocking.
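For reference, both spellings are current API; only the old async= keyword was renamed to non_blocking=. A minimal illustration (not part of the PR's diff):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    data = torch.randn(4, 3, 32, 32)

    # .cuda() itself is not deprecated; only the old async= flag became non_blocking=.
    a = data.cuda(non_blocking=True)

    # Equivalent .to() form suggested in the review comment above.
    b = data.to(device, non_blocking=True)

    assert torch.equal(a, b)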
def test_suite():
    suite = unittest.TestSuite()
    suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16.jit.pt')))
Can you please add this comment here and to test_ptq_dataloader_calibrator.py as well?
# You need a pre-trained VGG cifar10 model to run this test. Please follow instructions at
# https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq/training/vgg16 to export this model.
I added it to test_ptq_trt_calibrator.py but forgot it in the other place.
@@ -94,10 +94,18 @@ struct CompileSpec : torch::CustomClassHolder {
      input_ranges.push_back(*ir);
    }

    int64_t getPTQCalibratorHandle() {
      return (int64_t)ptq_calibrator;
Can you let me know why we do this cast to int64_t and then cast it back in the setPTQCalibratorViaHandle call?
TorchBind cannot handle pointers as arguments, so this was the cheapest way to get a pointer added to the struct. We get the int64_t-cast pointer from the original struct, and there is void setPTQCalibratorViaHandle(int64_t handle) to set the pointer from an int64_t in a struct owned by TorchBind.
Neither of these functions gets exposed to the user; they are used purely internally.
test Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
Force-pushed from e19fe2e to 088d586.
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/docsrc/conf.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
--- /workspace/tests/py/test_ptq_dataloader_calibrator.py (original)
+++ /workspace/tests/py/test_ptq_dataloader_calibrator.py (reformatted)
@@ -82,7 +82,7 @@
def test_suite():
suite = unittest.TestSuite()
# You need a pre-trained VGG cifar10 model to run this test. Please follow instructions at
-# https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq/training/vgg16 to export this model.
+ # https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq/training/vgg16 to export this model.
suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16.jit.pt')))
return suite
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
ERROR: Some files do not conform to style guidelines
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Description
Adds support for PTQ through the PyTorch to_backend api.
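Since PTQ needs a stream of representative inputs, a calibration pipeline for the CIFAR10 VGG16 tests referenced above might be assembled roughly as follows. This is only a sketch: the dataset split, normalization statistics, and batch size are illustrative assumptions, and TRTEntropyCalibrator refers to the class from test_ptq_trt_calibrator.py shown earlier in the thread.

import torch
import torchvision
import torchvision.transforms as transforms

# Illustrative preprocessing; the exact statistics used by the tests are not shown in this thread.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

calib_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
calib_loader = torch.utils.data.DataLoader(calib_set, batch_size=1, shuffle=False)

# Hand the dataloader to the entropy calibrator defined in the test file,
# optionally writing a calibration cache for later runs.
calibrator = TRTEntropyCalibrator(calib_loader, use_cache=False, cache_file="./calibration.cache")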