Skip to content

Commit 98527d2

Browse files
committed
fix(//cpp/benchmark): reorder benchmark so the FP16 bn issue in JIT doesn't
interfere with TRTorch Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 4741246 commit 98527d2

File tree

2 files changed

+30
-18
lines changed

2 files changed

+30
-18
lines changed

Diff for: cpp/benchmark/README.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Benchmarking
22

3-
This is a quick benchmarking application for TRTorch. It lets you run supported TorchScript modules both in JIT and TRT and returns the average runtime and throughput.
3+
This is a quick benchmarking application for TRTorch. It lets you run supported TorchScript modules both in JIT and TRT and returns the average runtime and throughput.
44

55
## Compilation / Usage
66

@@ -20,12 +20,14 @@ bazel run //cpp/benchmark --cxxopt="-DNDEBUG" --cxxopt="-DJIT" --cxxopt="-DTRT"
2020

2121
### Options
2222

23-
You can run a module with JIT or TRT via TRTorch in either FP32 or FP16. These options are controlled by preprocessor directives.
23+
You can run a module with JIT or TRT via TRTorch in either FP32 or FP16. These options are controlled by preprocessor directives.
2424

2525
- To enable JIT profiling, add the argument `--cxxopt="-DJIT"`
2626

2727
- To enable TRT profiling, add the argument `--cxxopt="-DTRT"`
2828

2929
- To enable FP16 execution, add the argument `--cxxopt="-DHALF"`
3030

31+
- To also save the TRT engine, add the argument `--cxxopt="-DSAVE_ENGINE"`
32+
3133
> It's suggested to also define `--cxxopt="-DNDEBUG"` to suppress debug information

Diff for: cpp/benchmark/main.cpp

+26-16
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,6 @@ int main(int argc, const char* argv[]) {
105105

106106
mod.to(at::kCUDA);
107107

108-
#ifdef HALF
109-
mod.to(torch::kHalf);
110-
for (auto layer : mod.named_modules()) {
111-
if (layer.name.find(".bn") != std::string::npos) {
112-
layer.value.to(torch::kFloat);
113-
}
114-
}
115-
#endif
116-
117108
std::vector<std::vector<int64_t>> dims;
118109
for (int i = 2; i < argc; i++) {
119110
auto arg = std::string(argv[i]);
@@ -129,23 +120,42 @@ int main(int argc, const char* argv[]) {
129120

130121
at::globalContext().setBenchmarkCuDNN(true);
131122

132-
#ifdef JIT
133-
auto jit_runtimes = benchmark_module(mod, dims[0]);
134-
print_avg_std_dev("JIT", jit_runtimes, dims[0][0]);
135-
#endif
136-
137123
#ifdef TRT
138124
auto extra_info = trtorch::ExtraInfo(dims);
139-
extra_info.workspace_size = 1 << 24;
125+
extra_info.workspace_size = 1 << 20;
140126

141127
#ifdef HALF
142-
extra_info.op_precision = at::kHalf;
128+
extra_info.op_precision = torch::kF16;
143129
#endif
144130

145131
auto trt_mod = trtorch::CompileGraph(mod, extra_info);
132+
133+
#ifdef SAVE_ENGINE
134+
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
135+
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info);
136+
std::ofstream out("/tmp/engine_converted_from_jit.trt");
137+
out << engine;
138+
out.close();
139+
#endif
140+
146141
auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
147142
print_avg_std_dev("JIT/TRT", trt_runtimes, dims[0][0]);
148143
#endif
149144

145+
146+
#ifdef HALF
147+
mod.to(torch::kHalf);
148+
for (auto layer : mod.named_modules()) {
149+
if (layer.name.find(".bn") != std::string::npos) {
150+
layer.value.to(torch::kFloat);
151+
}
152+
}
153+
#endif
154+
155+
#ifdef JIT
156+
auto jit_runtimes = benchmark_module(mod, dims[0]);
157+
print_avg_std_dev("JIT", jit_runtimes, dims[0][0]);
158+
#endif
159+
150160
std::cout << "ok\n";
151161
}

0 commit comments

Comments
 (0)