@@ -35,7 +35,6 @@ struct Resize : public torch::data::transforms::TensorTransform<torch::Tensor> {
35
35
torch::jit::Module compile_int8_model (const std::string& data_dir, torch::jit::Module& mod) {
36
36
auto calibration_dataset = datasets::CIFAR10 (data_dir, datasets::CIFAR10::Mode::kTest )
37
37
.use_subset (320 )
38
- .map (Resize ({300 , 300 }))
39
38
.map (torch::data::transforms::Normalize<>({0.4914 , 0.4822 , 0.4465 },
40
39
{0.2023 , 0.1994 , 0.2010 }))
41
40
.map (torch::data::transforms::Stack<>());
@@ -48,7 +47,7 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M
48
47
auto calibrator = trtorch::ptq::make_int8_calibrator (std::move (calibration_dataloader), calibration_cache_file, true );
49
48
50
49
51
- std::vector<std::vector<int64_t >> input_shape = {{32 , 3 , 300 , 300 }};
50
+ std::vector<std::vector<int64_t >> input_shape = {{32 , 3 , 32 , 32 }};
52
51
// / Configure settings for compilation
53
52
auto extra_info = trtorch::ExtraInfo ({input_shape});
54
53
// / Set operating precision to INT8
@@ -99,7 +98,6 @@ int main(int argc, const char* argv[]) {
99
98
100
99
// / Dataloader moved into calibrator so need another for inference
101
100
auto eval_dataset = datasets::CIFAR10 (data_dir, datasets::CIFAR10::Mode::kTest )
102
- .map (Resize ({300 , 300 }))
103
101
.map (torch::data::transforms::Normalize<>({0.4914 , 0.4822 , 0.4465 },
104
102
{0.2023 , 0.1994 , 0.2010 }))
105
103
.map (torch::data::transforms::Stack<>());
@@ -131,7 +129,7 @@ int main(int argc, const char* argv[]) {
131
129
if (images.sizes ()[0 ] < 32 ) {
132
130
// / To handle smaller batches until Optimization profiles work with Int8
133
131
auto diff = 32 - images.sizes ()[0 ];
134
- auto img_padding = torch::zeros ({diff, 3 , 300 , 300 }, {torch::kCUDA });
132
+ auto img_padding = torch::zeros ({diff, 3 , 32 , 32 }, {torch::kCUDA });
135
133
auto target_padding = torch::zeros ({diff}, {torch::kCUDA });
136
134
images = torch::cat ({images, img_padding}, 0 );
137
135
targets = torch::cat ({targets, target_padding}, 0 );
@@ -152,7 +150,7 @@ int main(int argc, const char* argv[]) {
152
150
std::cout << " Accuracy of quantized model on test set: " << 100 * (correct / total) << " %" << std::endl;
153
151
154
152
// / Time execution in JIT-FP32 and TRT-INT8
155
- std::vector<std::vector<int64_t >> dims = {{32 , 3 , 300 , 300 }};
153
+ std::vector<std::vector<int64_t >> dims = {{32 , 3 , 32 , 32 }};
156
154
157
155
auto jit_runtimes = benchmark_module (mod, dims[0 ]);
158
156
print_avg_std_dev (" JIT model FP32" , jit_runtimes, dims[0 ][0 ]);
0 commit comments