@@ -42,24 +42,24 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M

  auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);

-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
  /// Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
  /// Set operating precision to INT8
-  compile_spec.enable_precisions.insert(torch::kI8);
+  compile_spec.enabled_precisions.insert(torch::kF16);
+  compile_spec.enabled_precisions.insert(torch::kI8);
  /// Use the TensorRT Entropy Calibrator
  compile_spec.ptq_calibrator = calibrator;
  /// Set max batch size for the engine
  compile_spec.max_batch_size = 32;
  /// Set a larger workspace
  compile_spec.workspace_size = 1 << 28;

-  mod.eval();
-
#ifdef SAVE_ENGINE
  std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
  auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
-  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
  out << engine;
  out.close();
#endif
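
Note on the precision change above: enabled_precisions is a set, so the new code registers both FP16 and INT8, letting TensorRT pick a supported kernel per layer instead of being limited to INT8 alone. For context, a minimal sketch of how compile_int8_model presumably finishes consuming this spec; the CompileGraph call below is assumed from the TRTorch C++ API of this era and is not part of the diff:

  /// Sketch only: compile the module with the INT8 spec built above
  /// and hand the optimized module back to the caller.
  auto trt_mod = trtorch::CompileGraph(mod, compile_spec);
  return trt_mod;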
@@ -86,60 +86,53 @@ int main(int argc, const char* argv[]) {
    return -1;
  }

+  mod.eval();
+
  /// Create the calibration dataset
  const std::string data_dir = std::string(argv[2]);
-  auto trt_mod = compile_int8_model(data_dir, mod);

  /// Dataloader moved into calibrator so need another for inference
  auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
                          .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                          .map(torch::data::transforms::Stack<>());
  auto eval_dataloader = torch::data::make_data_loader(
      std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));

  /// Check the FP32 accuracy in JIT
-  float correct = 0.0, total = 0.0;
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
  for (auto batch : *eval_dataloader) {
    auto images = batch.data.to(torch::kCUDA);
    auto targets = batch.target.to(torch::kCUDA);

    auto outputs = mod.forward({images});
    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));

-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
  }
-  std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
+
+  /// Compile Graph
+  auto trt_mod = compile_int8_model(data_dir, mod);

  /// Check the INT8 accuracy in TRT
-  correct = 0.0;
-  total = 0.0;
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
  for (auto batch : *eval_dataloader) {
    auto images = batch.data.to(torch::kCUDA);
    auto targets = batch.target.to(torch::kCUDA);

-    if (images.sizes()[0] < 32) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      auto diff = 32 - images.sizes()[0];
-      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
-      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
-      images = torch::cat({images, img_padding}, 0);
-      targets = torch::cat({targets, target_padding}, 0);
-    }
-
    auto outputs = trt_mod.forward({images});
    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
    predictions = predictions.reshape(predictions.sizes()[0]);

-    if (predictions.sizes()[0] != targets.sizes()[0]) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      predictions = predictions.slice(0, 0, targets.sizes()[0]);
-    }
-
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
  }
-  std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
+
+  std::cout << "Accuracy of JIT model on test set: " << jit_accuracy.item().toFloat() << "%" << std::endl;
+  std::cout << "Accuracy of quantized model on test set: " << trt_accuracy.item().toFloat() << "%" << std::endl;

  /// Time execution in JIT-FP32 and TRT-INT8
  std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};
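
The hunk ends where the timing comparison begins; dims holds the 32x3x32x32 input shape used for benchmarking. A minimal timing sketch under assumed names (benchmark_ms is hypothetical and not part of this commit) that matches that shape:

  #include <chrono>

  /// Hypothetical helper: average forward-pass latency in milliseconds for a
  /// random CUDA batch of CIFAR-10 shape; warm-up iterations are excluded.
  float benchmark_ms(torch::jit::Module& mod, int iters = 100) {
    auto input = torch::randn({32, 3, 32, 32}, {torch::kCUDA});
    for (int i = 0; i < 10; i++) {
      mod.forward({input});
    }
    auto start = std::chrono::high_resolution_clock::now();
    torch::Tensor out;
    for (int i = 0; i < iters; i++) {
      out = mod.forward({input}).toTensor();
    }
    out = out.to(torch::kCPU); /// copy to host so GPU work finishes before the timer stops
    auto end = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<float, std::milli>(end - start).count() / iters;
  }

Called as benchmark_ms(mod) and benchmark_ms(trt_mod), this would produce the JIT-FP32 vs TRT-INT8 comparison that the comment above introduces.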