
Commit d8f5d29

feat(//cpp/int8/qat): QAT application release
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 004bf53 commit d8f5d29

20 files changed, +1819 -0 lines changed

Diff for: cpp/int8/benchmark/BUILD

+17
@@ -0,0 +1,17 @@
package(default_visibility = ["//visibility:public"])

cc_library(
    name = "benchmark",
    srcs = [
        "benchmark.cpp",
        "timer.h",
    ],
    hdrs = [
        "benchmark.h",
    ],
    deps = [
        "//cpp/api:trtorch",
        "@libtorch",
        "@libtorch//:caffe2",
    ],
)

Diff for: cpp/int8/benchmark/benchmark.cpp

+73
@@ -0,0 +1,73 @@
#include "ATen/Context.h"
#include "c10/cuda/CUDACachingAllocator.h"
#include "cuda_runtime_api.h"
#include "torch/script.h"
#include "torch/torch.h"
#include "trtorch/trtorch.h"

#include "timer.h"

#define NUM_WARMUP_RUNS 20
#define NUM_RUNS 100

// Benchmarking code
void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size) {
  float avg_runtime = std::accumulate(runtimes.begin(), runtimes.end(), 0.0) / runtimes.size();
  float fps = (1000.f / avg_runtime) * batch_size;
  std::cout << "[" << type << "]: batch_size: " << batch_size << "\n Average latency: " << avg_runtime
            << " ms\n Average FPS: " << fps << " fps" << std::endl;

  std::vector<float> rt_diff(runtimes.size());
  std::transform(runtimes.begin(), runtimes.end(), rt_diff.begin(), [avg_runtime](float x) { return x - avg_runtime; });
  float rt_sq_sum = std::inner_product(rt_diff.begin(), rt_diff.end(), rt_diff.begin(), 0.0);
  float rt_std_dev = std::sqrt(rt_sq_sum / runtimes.size());

  std::vector<float> fps_diff(runtimes.size());
  std::transform(runtimes.begin(), runtimes.end(), fps_diff.begin(), [fps, batch_size](float x) {
    return ((1000.f / x) * batch_size) - fps;
  });
  float fps_sq_sum = std::inner_product(fps_diff.begin(), fps_diff.end(), fps_diff.begin(), 0.0);
  float fps_std_dev = std::sqrt(fps_sq_sum / runtimes.size());
  std::cout << " Latency Standard Deviation: " << rt_std_dev << "\n FPS Standard Deviation: " << fps_std_dev
            << "\n(excluding initial warmup runs)" << std::endl;
}

std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape) {
  auto execution_timer = timers::PreciseCPUTimer();
  std::vector<float> execution_runtimes;

  for (uint64_t i = 0; i < NUM_WARMUP_RUNS; i++) {
    std::vector<torch::jit::IValue> inputs_ivalues;
    auto in = at::rand(shape, {at::kCUDA});
#ifdef HALF
    in = in.to(torch::kHalf);
#endif
    inputs_ivalues.push_back(in.clone());

    cudaDeviceSynchronize();
    mod.forward(inputs_ivalues);
    cudaDeviceSynchronize();
  }

  for (uint64_t i = 0; i < NUM_RUNS; i++) {
    std::vector<torch::jit::IValue> inputs_ivalues;
    auto in = at::rand(shape, {at::kCUDA});
#ifdef HALF
    in = in.to(torch::kHalf);
#endif
    inputs_ivalues.push_back(in.clone());
    cudaDeviceSynchronize();

    execution_timer.start();
    mod.forward(inputs_ivalues);
    cudaDeviceSynchronize();
    execution_timer.stop();

    auto time = execution_timer.milliseconds();
    execution_timer.reset();
    execution_runtimes.push_back(time);

    c10::cuda::CUDACachingAllocator::emptyCache();
  }
  return execution_runtimes;
}

Diff for: cpp/int8/benchmark/benchmark.h

+4
@@ -0,0 +1,4 @@
#pragma once

void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size);
std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape);
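
These two declarations are what a driver links against. A minimal sketch of how they might be called follows; the module filename and input shape are illustrative assumptions, not part of this commit:

#include <vector>
#include "torch/script.h"
#include "benchmark.h"

int main() {
  // Hypothetical TorchScript module exported ahead of time (name is an assumption).
  auto mod = torch::jit::load("trained_vgg16_qat.jit.pt");
  mod.to(at::kCUDA);
  mod.eval();

  // Benchmark a batch of 32 CIFAR10-sized inputs and report latency/FPS.
  std::vector<int64_t> shape = {32, 3, 32, 32};
  auto runtimes = benchmark_module(mod, shape);
  print_avg_std_dev("JIT", runtimes, shape[0]);
  return 0;
}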

Diff for: cpp/int8/benchmark/timer.h

+44
@@ -0,0 +1,44 @@
#pragma once
#include <chrono>

namespace timers {
class TimerBase {
 public:
  virtual void start() {}
  virtual void stop() {}
  float microseconds() const noexcept {
    return mMs * 1000.f;
  }
  float milliseconds() const noexcept {
    return mMs;
  }
  float seconds() const noexcept {
    return mMs / 1000.f;
  }
  void reset() noexcept {
    mMs = 0.f;
  }

 protected:
  float mMs{0.0f};
};

template <typename Clock>
class CPUTimer : public TimerBase {
 public:
  using clock_type = Clock;

  void start() {
    mStart = Clock::now();
  }
  void stop() {
    mStop = Clock::now();
    mMs += std::chrono::duration<float, std::milli>{mStop - mStart}.count();
  }

 private:
  std::chrono::time_point<Clock> mStart, mStop;
}; // class CPUTimer

using PreciseCPUTimer = CPUTimer<std::chrono::high_resolution_clock>;
} // namespace timers
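
For orientation, a self-contained sketch of the timer used on its own; the sleep is a stand-in workload, purely for illustration:

#include <chrono>
#include <iostream>
#include <thread>
#include "timer.h"

int main() {
  timers::PreciseCPUTimer timer;
  timer.start();
  // Placeholder for real work; benchmark_module times mod.forward() here.
  std::this_thread::sleep_for(std::chrono::milliseconds(5));
  timer.stop();
  // stop() accumulates elapsed time into the timer until reset() is called.
  std::cout << "Elapsed: " << timer.milliseconds() << " ms" << std::endl;
  timer.reset();
  return 0;
}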

Diff for: cpp/int8/datasets/BUILD

+14
@@ -0,0 +1,14 @@
package(default_visibility = ["//visibility:public"])

cc_library(
    name = "cifar10",
    srcs = [
        "cifar10.cpp",
    ],
    hdrs = [
        "cifar10.h",
    ],
    deps = [
        "@libtorch",
    ],
)

Diff for: cpp/int8/datasets/cifar10.cpp

+137
@@ -0,0 +1,137 @@
// #include "cpp/int8/ptq/datasets/cifar10.h"
#include "cifar10.h"
#include "torch/data/example.h"
#include "torch/torch.h"
#include "torch/types.h"

#include <cstddef>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace datasets {
namespace {
constexpr const char* kTrainFilenamePrefix = "data_batch_";
constexpr const uint32_t kNumTrainFiles = 5;
constexpr const char* kTestFilename = "test_batch.bin";
constexpr const size_t kLabelSize = 1; // B
constexpr const size_t kImageSize = 3072; // B
constexpr const size_t kImageDim = 32;
constexpr const size_t kImageChannels = 3;
constexpr const size_t kBatchSize = 10000;

std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) {
  std::ifstream batch;
  batch.open(path, std::ios::in | std::ios::binary | std::ios::ate);

  auto file_size = batch.tellg();
  std::unique_ptr<char[]> buf(new char[file_size]);

  batch.seekg(0, std::ios::beg);
  batch.read(buf.get(), file_size);
  batch.close();

  std::vector<uint8_t> labels;
  std::vector<torch::Tensor> images;
  labels.reserve(kBatchSize);
  images.reserve(kBatchSize);

  for (size_t i = 0; i < kBatchSize; i++) {
    uint8_t label = buf[i * (kImageSize + kLabelSize)];
    std::vector<uint8_t> image;
    image.reserve(kImageSize);
    std::copy(
        &buf[i * (kImageSize + kLabelSize) + kLabelSize],
        &buf[i * (kImageSize + kLabelSize) + kLabelSize + kImageSize],
        std::back_inserter(image));
    labels.push_back(label);
    auto image_tensor =
        torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8))
            .to(torch::kF32);
    images.push_back(image_tensor);
  }

  auto labels_tensor =
      torch::from_blob(labels.data(), {kBatchSize}, torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32);
  assert(labels_tensor.size(0) == kBatchSize);

  auto images_tensor = torch::stack(images);
  assert(images_tensor.size(0) == kBatchSize);

  return std::make_pair(images_tensor, labels_tensor);
}

std::pair<torch::Tensor, torch::Tensor> read_train_data(const std::string& root) {
  std::vector<torch::Tensor> images, targets;
  for (uint32_t i = 1; i <= kNumTrainFiles; i++) {
    std::stringstream ss;
    ss << root << '/' << kTrainFilenamePrefix << i << ".bin";
    auto batch = read_batch(ss.str());
    images.push_back(batch.first);
    targets.push_back(batch.second);
  }

  torch::Tensor image_tensor =
      std::accumulate(++images.begin(), images.end(), *images.begin(), [&](torch::Tensor a, torch::Tensor b) {
        return torch::cat({a, b}, 0);
      });
  torch::Tensor target_tensor =
      std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) {
        return torch::cat({a, b}, 0);
      });

  return std::make_pair(image_tensor, target_tensor);
}

std::pair<torch::Tensor, torch::Tensor> read_test_data(const std::string& root) {
  std::stringstream ss;
  ss << root << '/' << kTestFilename;
  return read_batch(ss.str());
}
} // namespace

CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) {
  std::pair<torch::Tensor, torch::Tensor> data;
  if (mode_ == Mode::kTrain) {
    data = read_train_data(root);
  } else {
    data = read_test_data(root);
  }

  images_ = std::move(data.first);
  targets_ = std::move(data.second);
  assert(images_.sizes()[0] == targets_.sizes()[0]);
}

torch::data::Example<> CIFAR10::get(size_t index) {
  return {images_[index], targets_[index]};
}

c10::optional<size_t> CIFAR10::size() const {
  return images_.size(0);
}

bool CIFAR10::is_train() const noexcept {
  return mode_ == Mode::kTrain;
}

const torch::Tensor& CIFAR10::images() const {
  return images_;
}

const torch::Tensor& CIFAR10::targets() const {
  return targets_;
}

CIFAR10&& CIFAR10::use_subset(int64_t new_size) {
  assert(new_size <= images_.sizes()[0]);
  images_ = images_.slice(0, 0, new_size);
  targets_ = targets_.slice(0, 0, new_size);
  return std::move(*this);
}

} // namespace datasets

Diff for: cpp/int8/datasets/cifar10.h

+45
@@ -0,0 +1,45 @@
#pragma once

#include "torch/data/datasets/base.h"
#include "torch/data/example.h"
#include "torch/types.h"

#include <cstddef>
#include <string>

namespace datasets {
// The CIFAR10 Dataset
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
 public:
  // The mode in which the dataset is loaded
  enum class Mode { kTrain, kTest };

  // Loads CIFAR10 from the un-tarred binary files. The dataset can be found at
  // https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
  // The root path should be the directory that contains the contents of the
  // tarball.
  explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);

  // Returns the pair at index in the dataset
  torch::data::Example<> get(size_t index) override;

  // The size of the dataset
  c10::optional<size_t> size() const override;

  // The mode the dataset is in
  bool is_train() const noexcept;

  // Returns all images stacked into a single tensor
  const torch::Tensor& images() const;

  // Returns all targets stacked into a single tensor
  const torch::Tensor& targets() const;

  // Trims the dataset to the first n pairs
  CIFAR10&& use_subset(int64_t new_size);

 private:
  Mode mode_;
  torch::Tensor images_, targets_;
};
} // namespace datasets
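
A hedged sketch of how this dataset can feed a libtorch data loader, e.g. for INT8 calibration; the root path, subset size, normalization constants, and batch size are assumptions for illustration, not taken from this diff:

#include <iostream>
#include "cifar10.h"
#include "torch/torch.h"

int main() {
  // Assumed location of the un-tarred CIFAR10 binary files.
  auto calib_dataset = datasets::CIFAR10("data/cifar-10-batches-bin", datasets::CIFAR10::Mode::kTest)
                           .use_subset(320)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
  auto calib_dataloader = torch::data::make_data_loader(
      std::move(calib_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));

  for (auto& batch : *calib_dataloader) {
    // Each batch is a stacked Example<>: data is [32, 3, 32, 32], target is [32].
    std::cout << batch.data.sizes() << std::endl;
  }
  return 0;
}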

Diff for: cpp/int8/ptq/BUILD

+22
@@ -0,0 +1,22 @@
package(default_visibility = ["//visibility:public"])

cc_binary(
    name = "ptq",
    srcs = [
        "main.cpp",
    ],
    copts = [
        "-pthread",
    ],
    linkopts = [
        "-lpthread",
    ],
    deps = [
        "//cpp/api:trtorch",
        "//cpp/int8/benchmark",
        "//cpp/int8/datasets:cifar10",
        "@libtorch",
        "@libtorch//:caffe2",
        "@tensorrt//:nvinfer",
    ],
)
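
With a configured TRTorch Bazel workspace (libtorch, CUDA, and TensorRT repositories set up in the WORKSPACE), a binary target like this would typically be built and run along these lines; the exact flags and any runtime arguments depend on main.cpp and the local setup, so treat this as an assumption rather than part of the diff:

bazel build //cpp/int8/ptq --compilation_mode=opt
./bazel-bin/cpp/int8/ptq/ptq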
