Commit b62df15

Merge pull request #1242 from pytorch/ptq-tutorial-compile_spec-fix

Update PTQ example to fix new compile_spec requirements

2 parents: d0e471f + 12a1739

File tree: 1 file changed, +13 −20 lines
docsrc/tutorials/ptq.rst (+13 −20)
@@ -167,19 +167,16 @@ a TensorRT calibrator by providing desired configuration. The following code dem
         algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
         device=torch.device('cuda:0'))
 
-    compile_spec = {
-        "inputs": [torch_tensorrt.Input((1, 3, 32, 32))],
-        "enabled_precisions": {torch.float, torch.half, torch.int8},
-        "calibrator": calibrator,
-        "device": {
-            "device_type": torch_tensorrt.DeviceType.GPU,
-            "gpu_id": 0,
-            "dla_core": 0,
-            "allow_gpu_fallback": False,
-            "disable_tf32": False
-        }
-    }
-    trt_mod = torch_tensorrt.compile(model, compile_spec)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
+                                     enabled_precisions={torch.float, torch.half, torch.int8},
+                                     calibrator=calibrator,
+                                     device={
+                                         "device_type": torch_tensorrt.DeviceType.GPU,
+                                         "gpu_id": 0,
+                                         "dla_core": 0,
+                                         "allow_gpu_fallback": False,
+                                         "disable_tf32": False
+                                     })
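For readers updating their own scripts, the hunk above amounts to passing the former ``compile_spec`` dictionary entries as keyword arguments instead. A minimal plain-Python sketch (``compile_stub`` and the placeholder string values are stand-ins for illustration, not the real ``torch_tensorrt`` API, which needs a CUDA build to run) shows that a legacy spec dict whose keys match the new parameter names can simply be unpacked with ``**``:

```python
# Legacy-style spec dict; the values are string placeholders standing in for
# real torch_tensorrt objects (torch_tensorrt.Input, torch dtypes, a calibrator).
legacy_spec = {
    "inputs": ["Input((1, 3, 32, 32))"],
    "enabled_precisions": {"float", "half", "int8"},
    "calibrator": "entropy_calibrator",
}

# Stand-in mirroring the keyword-based call shape the diff migrates to
# (hypothetical signature for illustration only, not torch_tensorrt.compile).
def compile_stub(model, *, inputs=None, enabled_precisions=None,
                 calibrator=None, device=None):
    return {"model": model, "inputs": inputs,
            "enabled_precisions": enabled_precisions,
            "calibrator": calibrator, "device": device}

# Unpacking the old dict reproduces the new keyword-argument call, so the
# migration is a one-line `**legacy_spec` when the keys match the parameters.
settings = compile_stub("my_model", **legacy_spec)
```

Keys that no longer exist as parameters would raise a ``TypeError`` on unpacking, which makes stale spec entries easy to spot.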
In the cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how to use ``CacheCalibrator`` in INT8 mode.
@@ -188,13 +185,9 @@ to use ``CacheCalibrator`` in INT8 mode.
 
     calibrator = torch_tensorrt.ptq.CacheCalibrator("./calibration.cache")
 
-    compile_settings = {
-        "inputs": [torch_tensorrt.Input([1, 3, 32, 32])],
-        "enabled_precisions": {torch.float, torch.half, torch.int8},
-        "calibrator": calibrator,
-    }
-
-    trt_mod = torch_tensorrt.compile(model, compile_settings)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input([1, 3, 32, 32])],
+                                     enabled_precisions={torch.float, torch.half, torch.int8},
+                                     calibrator=calibrator)
If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
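On that last point, a pre-existing calibrator written against the TensorRT Python API exposes a small, fixed set of methods. The sketch below shows the assumed shape of such a class in plain Python; in real use it would subclass ``tensorrt.IInt8EntropyCalibrator2`` and ``get_batch`` would return device pointers to calibration data (the class name, cache path, and batch logic here are illustrative assumptions, not the tutorial's code):

```python
class SketchEntropyCalibrator:
    # Real use: class SketchEntropyCalibrator(tensorrt.IInt8EntropyCalibrator2)
    """Illustrative stand-in for a hand-written INT8 calibrator."""

    def __init__(self, cache_file="./calibration.cache"):
        self.cache_file = cache_file

    def get_batch_size(self):
        # Batch size of the calibration batches this calibrator serves.
        return 1

    def get_batch(self, names):
        # Would return device pointers to the next calibration batch;
        # returning None signals that calibration data is exhausted.
        return None

    def read_calibration_cache(self):
        # Reuse a previous calibration run if a cache file exists.
        try:
            with open(self.cache_file, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None

    def write_calibration_cache(self, cache):
        # Persist the computed scales so later builds can skip calibration.
        with open(self.cache_file, "wb") as f:
            f.write(cache)
```

Such an instance would then be passed straight through as the ``calibrator=`` argument shown in the diffs above.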
