
Commit 7efa11d

fix(//cpp/ptq): fixing bad accuracy in just the example code
Signed-off-by: Naren Dasan <[email protected]>
1 parent b4a2dd6 commit 7efa11d

File tree

3 files changed: +39 −33 lines

Diff for: cpp/ptq/main.cpp (+23 −30)

@@ -42,24 +42,24 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M
 
   auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
 
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   /// Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   /// Set operating precision to INT8
-  compile_spec.enable_precisions.insert(torch::kI8);
+  compile_spec.enabled_precisions.insert(torch::kF16);
+  compile_spec.enabled_precisions.insert(torch::kI8);
   /// Use the TensorRT Entropy Calibrator
   compile_spec.ptq_calibrator = calibrator;
   /// Set max batch size for the engine
   compile_spec.max_batch_size = 32;
   /// Set a larger workspace
   compile_spec.workspace_size = 1 << 28;
 
-  mod.eval();
-
 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
-  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif
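
This hunk makes two API-level fixes: the input is now declared as a trtorch::CompileSpec::Input carrying an explicit DataType rather than a bare shape vector, and the misspelled enable_precisions field becomes enabled_precisions, with kF16 enabled alongside kI8 so TensorRT may fall back to half precision for layers that quantize poorly. Condensed, the resulting compile path looks like the sketch below (assuming the TRTorch API at this commit; trtorch::CompileGraph is the library's compile entry point, and calibrator comes from the surrounding example):

  // Sketch only: condensed from the post-change example code.
  // `mod` is a loaded torch::jit::Module, `calibrator` the INT8 entropy
  // calibrator built earlier in compile_int8_model.
  std::vector<trtorch::CompileSpec::Input> inputs = {
      trtorch::CompileSpec::Input(
          std::vector<int64_t>({32, 3, 32, 32}),    // static NCHW input shape
          trtorch::CompileSpec::DataType::kFloat)}; // inputs remain FP32
  auto compile_spec = trtorch::CompileSpec(inputs);
  // Allow both FP16 and INT8 kernels; TensorRT picks per layer.
  compile_spec.enabled_precisions.insert(torch::kF16);
  compile_spec.enabled_precisions.insert(torch::kI8);
  compile_spec.ptq_calibrator = calibrator;
  compile_spec.max_batch_size = 32;
  compile_spec.workspace_size = 1 << 28; // 256 MiB scratch space
  auto trt_mod = trtorch::CompileGraph(mod, compile_spec);

Keeping the input DataType at kFloat means calibration and inference both feed FP32 tensors; the INT8 quantization happens inside the engine.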
@@ -86,60 +86,53 @@ int main(int argc, const char* argv[]) {
     return -1;
   }
 
+  mod.eval();
+
   /// Create the calibration dataset
   const std::string data_dir = std::string(argv[2]);
-  auto trt_mod = compile_int8_model(data_dir, mod);
 
   /// Dataloader moved into calibrator so need another for inference
   auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
   auto eval_dataloader = torch::data::make_data_loader(
       std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));
 
   /// Check the FP32 accuracy in JIT
-  float correct = 0.0, total = 0.0;
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
     auto outputs = mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
 
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
   }
-  std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
+
+  /// Compile Graph
+  auto trt_mod = compile_int8_model(data_dir, mod);
 
   /// Check the INT8 accuracy in TRT
-  correct = 0.0;
-  total = 0.0;
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
-    if (images.sizes()[0] < 32) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      auto diff = 32 - images.sizes()[0];
-      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
-      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
-      images = torch::cat({images, img_padding}, 0);
-      targets = torch::cat({targets, target_padding}, 0);
-    }
-
     auto outputs = trt_mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
     predictions = predictions.reshape(predictions.sizes()[0]);
 
-    if (predictions.sizes()[0] != targets.sizes()[0]) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      predictions = predictions.slice(0, 0, targets.sizes()[0]);
-    }
-
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
   }
-  std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
+
+  std::cout << "Accuracy of JIT model on test set: " << jit_accuracy.item().toFloat() << "%" << std::endl;
+  std::cout << "Accuracy of quantized model on test set: " << trt_accuracy.item().toFloat() << "%" << std::endl;
 
   /// Time execution in JIT-FP32 and TRT-INT8
   std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};
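
This hunk targets the bad accuracy called out in the commit message: mod.eval() now runs in main() before anything is measured, the FP32 baseline is taken before compile_int8_model() consumes the dataset, and use_subset(3200) pins evaluation to exactly 100 full batches of 32, which is why the old padding/slicing workaround for the fixed-batch INT8 engine could be deleted. The per-batch .item() host syncs are also replaced by on-device accumulation; a minimal sketch of that pattern, with mod and eval_dataloader as in the example:

  // Sketch of the on-device accuracy accumulation used above.
  // `mod` and `eval_dataloader` are assumed from the surrounding example.
  torch::Tensor correct = torch::zeros({1}, {torch::kCUDA});
  torch::Tensor total = torch::zeros({1}, {torch::kCUDA});
  for (auto batch : *eval_dataloader) {
    auto images = batch.data.to(torch::kCUDA);
    auto targets = batch.target.to(torch::kCUDA);
    // torch::max over dim 1 returns (values, indices); indices = predicted class
    auto predictions = std::get<1>(torch::max(mod.forward({images}).toTensor(), 1, false));
    total += targets.sizes()[0];
    correct += torch::sum(torch::eq(predictions, targets));
  }
  // One host sync at the end instead of an .item() call per batch
  float accuracy = (correct / total).item().toFloat() * 100.0f;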

Diff for: tests/accuracy/BUILD (+10 −0)

@@ -3,6 +3,11 @@ filegroup(
     srcs = glob(["**/*.jit.pt"]),
 )
 
+filegroup(
+    name = "data",
+    srcs = glob(["datasets/**/*"])
+)
+
 test_suite(
     name = "aarch64_accuracy_tests",
     tests = [

@@ -28,6 +33,7 @@ cc_test(
     srcs = ["test_int8_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",

@@ -40,6 +46,7 @@ cc_test(
     srcs = ["test_fp16_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",

@@ -52,6 +59,7 @@ cc_test(
     srcs = ["test_fp32_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",

@@ -64,6 +72,7 @@ cc_test(
     srcs = ["test_dla_int8_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",

@@ -76,6 +85,7 @@ cc_test(
     srcs = ["test_dla_fp16_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",

Diff for: tests/accuracy/test_int8_accuracy.cpp (+6 −3)

@@ -14,13 +14,16 @@ TEST_P(AccuracyTests, INT8AccuracyIsClose) {
 
   std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache";
 
-  auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
+  auto calibrator =
+      trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, false);
   // auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
 
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   // Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   // Set operating precision to INT8
+  compile_spec.enabled_precisions.insert(torch::kF16);
   compile_spec.enabled_precisions.insert(torch::kI8);
   // Use the TensorRT Entropy Calibrator
   compile_spec.ptq_calibrator = calibrator;
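
The substantive change in the test is the calibrator's third argument flipping from true to false. In the TRTorch ptq helpers this flag controls whether an existing calibration cache is trusted; with false, calibration is recomputed from the dataloader on every run rather than replaying whatever /tmp/vgg16_TRT_ptq_calibration.cache a previous run left behind. A sketch of the two modes (the use_cache parameter name is taken from the TRTorch headers; treat it as an assumption):

  // Recalibrate from data each run, then write the cache file.
  auto fresh_calibrator = trtorch::ptq::make_int8_calibrator(
      std::move(calibration_dataloader), calibration_cache_file, /*use_cache=*/false);

  // Replay an existing cache with no dataloader at all (see the commented
  // line in the diff above).
  auto cached_calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);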
