Skip to content

Commit 749048c

Browse files
committed
fix: Fix deepcopy issues of PTQ calibrators
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 13cc024 commit 749048c

File tree

4 files changed

+15
-17
lines changed

4 files changed

+15
-17
lines changed

noxfile.py

-2
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,6 @@ def run_l2_trt_compatibility_tests(session):
342342
if not USE_HOST_DEPS:
343343
install_deps(session)
344344
install_torch_trt(session)
345-
download_models(session)
346-
train_model(session)
347345
run_trt_compatibility_tests(session)
348346
cleanup(session)
349347

py/torch_tensorrt/ptq.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def write_calibration_cache(self, cache):
5555
else:
5656
return b""
5757

58+
# deepcopy (which involves pickling) is performed on the compile_spec internally during compilation.
59+
# We register this __reduce__ function so that the pickler can identify the calibrator object returned by DataLoaderCalibrator during deepcopy.
60+
# The string returned by __reduce__ should be the object's local name relative to its module; see https://docs.python.org/3/library/pickle.html#object.__reduce__
61+
def __reduce__(self):
62+
return self.__class__.__name__
5863

5964
class DataLoaderCalibrator(object):
6065
"""
@@ -114,24 +119,25 @@ def __new__(cls, *args, **kwargs):
114119
"get_batch": get_cache_mode_batch if use_cache else get_batch,
115120
"read_calibration_cache": read_calibration_cache,
116121
"write_calibration_cache": write_calibration_cache,
122+
"__reduce__": __reduce__ # used when you deepcopy the DataLoaderCalibrator object
117123
}
118124

119125
# Using type metaclass to construct calibrator class based on algorithm type
120126
if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
121127
return type(
122-
"DataLoaderCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping
128+
"Int8EntropyCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping
123129
)()
124130
elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
125131
return type(
126-
"DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
132+
"Int8EntropyCalibrator2", (_C.IInt8EntropyCalibrator2,), attribute_mapping
127133
)()
128134
elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
129135
return type(
130-
"DataLoaderCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping
136+
"Int8LegacyCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping
131137
)()
132138
elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
133139
return type(
134-
"DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
140+
"Int8MinMaxCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
135141
)()
136142
else:
137143
log(

py/torch_tensorrt/ts/_compile_spec.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,7 @@ def _parse_input_signature(input_signature: Any):
226226

227227
def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
228228
# TODO: Use deepcopy to support partial compilation of collections
229-
compile_spec = {}
230-
for k, v in compile_spec_.items():
231-
if k != "calibrator":
232-
compile_spec[k] = deepcopy(v)
233-
else:
234-
compile_spec[k] = v
235-
229+
compile_spec = deepcopy(compile_spec_)
236230
info = _ts_C.CompileSpec()
237231

238232
if len(compile_spec["inputs"]) > 0:

tests/py/ptq/test_ptq_dataloader_calibrator.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,6 @@ def test_compile_script(self):
8181
device=torch.device("cuda:0"),
8282
)
8383

84-
fp32_test_acc = compute_accuracy(self.testing_dataloader, self.model)
85-
log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
86-
8784
compile_spec = {
8885
"inputs": [torchtrt.Input([1, 3, 32, 32])],
8986
"enabled_precisions": {torch.float, torch.int8},
@@ -96,8 +93,11 @@ def test_compile_script(self):
9693
"allow_gpu_fallback": False,
9794
},
9895
}
99-
10096
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
97+
98+
fp32_test_acc = compute_accuracy(self.testing_dataloader, self.model)
99+
log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
100+
101101
int8_test_acc = compute_accuracy(self.testing_dataloader, trt_mod)
102102
log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
103103
acc_diff = fp32_test_acc - int8_test_acc

0 commit comments

Comments
 (0)