
Commit dd443a6

feat(//core/quantization): skeleton of INT8 PTQ calibrator

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent aef6003 commit dd443a6

File tree: 10 files changed, +221 −11 lines changed

Diff for: .gitignore (+3 −1)

@@ -13,4 +13,6 @@ experiments/
 py/build/
 py/tmp/
 py/.eggs
-.vscode/
+.vscode/
+.DS_Store
+._DS_Store

Diff for: core/conversion/conversionctx/ConversionCtx.cpp (+10 −5)

@@ -19,7 +19,8 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
                      << "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
                      << "\n Max Workspace Size: " << s.workspace_size \
                      << "\n Device Type: " << s.device \
-                     << "\n Engine Capability: " << s.capability;
+                     << "\n Engine Capability: " << s.capability \
+                     << "\n Calibrator Created: " << (s.calibrator ? true : false);
     return os;
 }

@@ -36,13 +37,17 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)

     switch(settings.op_precision) {
     case nvinfer1::DataType::kHALF:
+        TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
         cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
         input_type = nvinfer1::DataType::kHALF;
         break;
-    // case nvinfer1::DataType::kINT8:
-    //     cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
-    //     input_type = nvinfer1::DataType::kFLOAT;
-    //     break;
+    case nvinfer1::DataType::kINT8:
+        TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
+        cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
+        input_type = nvinfer1::DataType::kINT8;
+        // If the calibrator is nullptr then TRT will use default quantization
+        cfg->setInt8Calibrator(settings.calibrator);
+        break;
     case nvinfer1::DataType::kFLOAT:
     default:
         input_type = nvinfer1::DataType::kFLOAT;
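One subtlety in the logging line added above: `operator<<` binds tighter than the conditional operator, so the parentheses around `(s.calibrator ? true : false)` are required. A self-contained illustration of the pitfall:

``` cpp
#include <iostream>

int main() {
    int* p = nullptr;
    // Without parentheses this would parse as (std::cout << "set: " << p) ? true : false,
    // streaming the raw pointer value and silently discarding the conditional.
    std::cout << "set: " << (p ? true : false) << std::endl;  // prints "set: 0"
    return 0;
}
```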

Diff for: core/conversion/conversionctx/ConversionCtx.h (+1 −0)

@@ -24,6 +24,7 @@ struct BuilderSettings {
     bool allow_gpu_fallback = true;
     nvinfer1::DeviceType device = nvinfer1::DeviceType::kGPU;
     nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
+    nvinfer1::IInt8Calibrator* calibrator = nullptr;
     uint64_t num_min_timing_iters = 2;
     uint64_t num_avg_timing_iters = 1;
     uint64_t workspace_size = 0;
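For context, a minimal sketch of how a caller might wire the new field together with `op_precision` (which the switch in ConversionCtx.cpp reads) to reach the `kINT8` branch. The helper function is illustrative and the unqualified type names assume the surrounding namespaces:

``` cpp
#include "NvInfer.h"
#include "core/conversion/conversionctx/ConversionCtx.h"

// Sketch only: my_calibrator is hypothetical and would come from the
// quantization factory introduced in this commit.
void configure_int8(nvinfer1::IInt8Calibrator* my_calibrator) {
    BuilderSettings settings;                           // exact namespace qualification assumed
    settings.op_precision = nvinfer1::DataType::kINT8;  // selects the kINT8 branch above
    settings.calibrator = my_calibrator;                // nullptr => TRT default quantization
    ConversionCtx ctx(settings);
}
```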

Diff for: core/quantization/BUILD

Whitespace-only changes.

Diff for: core/quantization/TRTEntropyCalibrator.cpp (+64 −0)

@@ -0,0 +1,64 @@
+#include <fstream>
+#include <iterator>
+
+#include "core/util/prelude.h"
+#include "core/quantization/quantization.h"
+
+namespace trtorch {
+namespace core {
+namespace quantization {
+
+Int8CalibratorImpl::Int8CalibratorImpl(QuantizationSettings&& settings)
+    : dataset_(std::move(settings.calibration_dataset)),
+      cache_file_path_(settings.calibration_cache_file),
+      use_cache_(settings.use_cache) {
+    buffers_.reserve(dataset_.size());
+}
+
+int Int8CalibratorImpl::GetBatchSize() const {
+    // TODO: skeleton - report the batch size of the calibration dataset
+    return 0;
+}
+
+bool Int8CalibratorImpl::GetBatch(void* bindings[], const char* names[], int num_bindings) {
+    // NOTE: is_next_batch and next_binding_batch are not defined yet in this skeleton
+    if (!is_next_batch) {
+        return false;
+    }
+
+    for (int i = 0; i < num_bindings; i++) {
+        auto batch = next_binding_batch(names[i]);
+        batch = batch.to(at::kCUDA).contiguous();
+        bindings[i] = batch.data_ptr();
+    }
+    return true;
+}
+
+const void* Int8CalibratorImpl::ReadCalibrationCache(size_t& length) {
+    cache_.clear();
+    std::ifstream cache_file(cache_file_path_, std::ios::binary);
+    cache_file >> std::noskipws;
+    if (use_cache_ && cache_file.good()) {
+        std::copy(std::istream_iterator<char>(cache_file),
+                  std::istream_iterator<char>(),
+                  std::back_inserter(cache_));
+    }
+    cache_size_ = cache_.size();
+    length = cache_size_;
+    return cache_size_ ? cache_.data() : nullptr;
+}
+
+void Int8CalibratorImpl::WriteCalibrationCache(const void* cache, size_t length) {
+    std::ofstream cache_file(cache_file_path_, std::ios::binary);
+    cache_file.write(reinterpret_cast<const char*>(cache), length);
+}
+
+nvinfer1::IInt8Calibrator* create_int8_calibrator(QuantizationSettings settings) {
+    auto kind = settings.calibrator_type;
+    auto calibrator_impl = Int8CalibratorImpl(std::move(settings));
+    switch (kind) {
+    case CalibratorKind::kMinMax:
+        return new TRTInt8MinMaxCalibrator(std::move(calibrator_impl));
+    case CalibratorKind::kEntropy:
+    default:
+        return new TRTInt8EntropyCalibrator(std::move(calibrator_impl));
+    }
+}
+
+} // namespace quantization
+} // namespace core
+} // namespace trtorch
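The read/write pair above follows the usual TensorRT calibration-cache pattern: the cache bytes are opaque to the application and are simply persisted between runs so calibration can be skipped on subsequent engine builds. A standalone sketch of that file round trip, independent of the classes in this commit:

``` cpp
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Read an opaque calibration cache from disk; returns an empty vector if the file is absent.
std::vector<char> read_cache(const std::string& path) {
    std::ifstream in(path, std::ios::binary);
    in >> std::noskipws;  // the cache is binary data, so keep whitespace bytes
    return std::vector<char>(std::istream_iterator<char>(in),
                             std::istream_iterator<char>());
}

// Persist the cache blob TensorRT hands back after calibration.
void write_cache(const std::string& path, const void* cache, size_t length) {
    std::ofstream out(path, std::ios::binary);
    out.write(reinterpret_cast<const char*>(cache), length);
}
```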

Diff for: core/quantization/quantization.h (+69 −0)

@@ -0,0 +1,69 @@
+#pragma once
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ATen/tensor.h"
+#include "NvInfer.h"
+
+namespace trtorch {
+namespace core {
+namespace quantization {
+
+enum class CalibratorKind {
+    kEntropy,
+    kMinMax,
+};
+
+// TODO: in the converter (or wherever inputs are bound), map the ordered
+// std::vector<at::Tensor> inputs to a map<input_name, at::Tensor>
+
+struct QuantizationSettings {
+    CalibratorKind calibrator_type = CalibratorKind::kEntropy;
+    std::string calibration_cache_file = "";
+    bool use_cache = false;
+    std::unordered_map<std::string, at::Tensor> calibration_dataset;
+};
+
+class CalibrationBatchStream {
+    // TODO: skeleton
+};
+
+class Int8CalibratorImpl {
+public:
+    Int8CalibratorImpl(QuantizationSettings&& settings);
+    int GetBatchSize() const;
+    bool GetBatch(void* bindings[], const char* names[], int num_bindings);
+    const void* ReadCalibrationCache(size_t& length);
+    void WriteCalibrationCache(const void* cache, size_t length);
+private:
+    std::unordered_map<std::string, at::Tensor> dataset_;
+    std::string cache_file_path_;
+    std::vector<at::Tensor> buffers_; // assumed type: staging for batches moved to the GPU
+    std::vector<char> cache_;
+    bool use_cache_;
+    size_t cache_size_ = 0;
+};
+
+class TRTInt8EntropyCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
+public:
+    TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(std::move(impl)) {}
+    int getBatchSize() const override { return impl_.GetBatchSize(); }
+    bool getBatch(void* bindings[], const char* names[], int nbBindings) override { return impl_.GetBatch(bindings, names, nbBindings); }
+    const void* readCalibrationCache(size_t& length) override { return impl_.ReadCalibrationCache(length); }
+    void writeCalibrationCache(const void* cache, size_t length) override { impl_.WriteCalibrationCache(cache, length); }
+private:
+    Int8CalibratorImpl impl_;
+};
+
+class TRTInt8MinMaxCalibrator : public nvinfer1::IInt8MinMaxCalibrator {
+public:
+    TRTInt8MinMaxCalibrator(Int8CalibratorImpl impl) : impl_(std::move(impl)) {}
+    int getBatchSize() const override { return impl_.GetBatchSize(); }
+    bool getBatch(void* bindings[], const char* names[], int nbBindings) override { return impl_.GetBatch(bindings, names, nbBindings); }
+    const void* readCalibrationCache(size_t& length) override { return impl_.ReadCalibrationCache(length); }
+    void writeCalibrationCache(const void* cache, size_t length) override { impl_.WriteCalibrationCache(cache, length); }
+private:
+    Int8CalibratorImpl impl_;
+};
+
+nvinfer1::IInt8Calibrator* create_int8_calibrator(QuantizationSettings settings);
+
+} // namespace quantization
+} // namespace core
+} // namespace trtorch
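A hedged sketch of how the factory at the bottom of the header might be used once the skeleton is filled in (the cache path is a placeholder):

``` cpp
#include <string>
#include <unordered_map>

#include "core/quantization/quantization.h"

nvinfer1::IInt8Calibrator* build_calibrator(std::unordered_map<std::string, at::Tensor> data) {
    trtorch::core::quantization::QuantizationSettings settings;
    settings.calibrator_type = trtorch::core::quantization::CalibratorKind::kEntropy;
    settings.calibration_cache_file = "/tmp/mnist_calibration.cache";  // placeholder path
    settings.use_cache = false;
    settings.calibration_dataset = std::move(data);
    return trtorch::core::quantization::create_int8_calibrator(std::move(settings));
}
```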

Diff for: cpp/ptq/BUILD (+12 −0)

@@ -0,0 +1,12 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_binary(
+    name = "ptq",
+    srcs = [
+        "main.cpp"
+    ],
+    deps = [
+        "@libtorch//:libtorch",
+        "//cpp/api:trtorch"
+    ],
+)

Diff for: cpp/ptq/README.md (+21 −0)

@@ -0,0 +1,21 @@
+# ptq
+
+This is a short example application that shows how to use TRTorch to perform post-training quantization for a module.
+
+## Compilation
+
+``` shell
+bazel build //cpp/ptq --cxxopt="-DNDEBUG"
+```
+
+If you want insight into what is going on under the hood or need debug symbols
+
+``` shell
+bazel build //cpp/ptq --compilation_mode=dbg
+```
+
+## Usage
+
+``` shell
+ptq <path-to-module> <path-to-mnist>
+```

Diff for: cpp/ptq/main.cpp (+36 −0)

@@ -0,0 +1,36 @@
+#include "torch/script.h"
+#include "torch/csrc/api/include/torch/data/datasets/mnist.h"
+#include "trtorch/trtorch.h"
+
+#include <iostream>
+#include <sstream>
+#include <memory>
+
+int main(int argc, const char* argv[]) {
+    if (argc < 3) {
+        std::cerr << "usage: ptq <path-to-module> <path-to-mnist>\n";
+        return -1;
+    }
+
+    torch::jit::script::Module mod;
+    try {
+        // Deserialize the ScriptModule from a file using torch::jit::load().
+        mod = torch::jit::load(argv[1]);
+    }
+    catch (const c10::Error& e) {
+        std::cerr << "error loading the model\n";
+        return -1;
+    }
+
+    const std::string data_dir = std::string(argv[2]);
+    auto calibration_dataset = torch::data::datasets::MNIST(data_dir, torch::data::datasets::MNIST::Mode::kTest)
+        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
+        .map(torch::data::transforms::Stack<>());
+    auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset),
+        torch::data::DataLoaderOptions().batch_size(32).workers(1));
+
+    for (auto& batch : *calibration_dataloader) {
+        std::cout << batch.data.sizes() << std::endl;
+    }
+}
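The loop currently only prints batch shapes; a hedged sketch of where it is presumably headed, collecting the loader's batches into the input-name to tensor map that `QuantizationSettings` expects (the binding name "input0" is a placeholder):

``` cpp
// Sketch: accumulate calibration batches keyed by engine input name.
std::vector<at::Tensor> batches;
for (auto& batch : *calibration_dataloader) {
    batches.push_back(batch.data);
}

std::unordered_map<std::string, at::Tensor> calibration_data;
calibration_data["input0"] = torch::cat(batches, /*dim=*/0);  // "input0" is a placeholder
```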

Diff for: cpp/trtorchexec/main.cpp (+5 −5)

@@ -12,7 +12,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
     maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
   }
   std::cout << "Max Difference: " << diff.abs().max().item<float>() << std::endl;
-  return diff.abs().max().item<float>() <= 2e-6 * maxValue;
+  return diff.abs().max().item<float>() <= 2e-5 * maxValue;
 }

 bool almostEqual(const at::Tensor& a, const at::Tensor& b) {

@@ -25,8 +25,8 @@ int main(int argc, const char* argv[]) {
            << " trtorchexec <path-to-exported-script-module> <min-input-size> <opt-input-size> <max-input-size>\n";
     return -1;
   }
-
-
+
+
   torch::jit::script::Module mod;
   try {
     // Deserialize the ScriptModule from a file using torch::jit::load().

@@ -38,7 +38,7 @@ int main(int argc, const char* argv[]) {
   }

   mod.to(at::kCUDA);
-
+
   std::vector<std::vector<int64_t>> dims;
   for (int i = 2; i < argc; i++) {
     auto arg = std::string(argv[i]);

@@ -74,7 +74,7 @@ int main(int argc, const char* argv[]) {
   torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
   std::vector<at::Tensor> jit_results;
   jit_results.push_back(jit_results_ivalues.toTensor());
-
+
   auto trt_mod = trtorch::CompileGraph(mod, dims);
   torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
