Skip to content

Commit 80808b7

Browse files
committed
feat(//cpp/trtorchexec): TRTorch exec now supports checking correctness
of multiple outputs Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 8171f79 commit 80808b7

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

Diff for: cpp/trtorchexec/main.cpp

+21-4
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,17 @@ int main(int argc, const char* argv[]) {
5555
dims.push_back(v);
5656
}
5757

58+
auto extra_info = trtorch::ExtraInfo(dims);
59+
extra_info.workspace_size = 1 << 24;
60+
5861
std::cout << "Checking operator support" << std::endl;
5962
if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
6063
std::cerr << "Method is not currently supported by TRTorch" << std::endl;
6164
return -1;
6265
}
6366

6467
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
65-
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims);
68+
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info);
6669
std::ofstream out("/tmp/engine_converted_from_jit.trt");
6770
out << engine;
6871
out.close();
@@ -75,14 +78,28 @@ int main(int argc, const char* argv[]) {
7578

7679
torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
7780
std::vector<at::Tensor> jit_results;
78-
jit_results.push_back(jit_results_ivalues.toTensor());
81+
if (jit_results_ivalues.isTensor()) {
82+
jit_results.push_back(jit_results_ivalues.toTensor());
83+
} else {
84+
auto results = jit_results_ivalues.toTuple()->elements();
85+
for (auto r : results) {
86+
jit_results.push_back(r.toTensor());
87+
}
88+
}
7989

8090
std::cout << "Compiling graph as module" << std::endl;
81-
auto trt_mod = trtorch::CompileGraph(mod, dims);
91+
auto trt_mod = trtorch::CompileGraph(mod, extra_info);
8292
std::cout << "Running TRT module" << std::endl;
8393
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
8494
std::vector<at::Tensor> trt_results;
85-
trt_results.push_back(trt_results_ivalues.toTensor());
95+
if (trt_results_ivalues.isTensor()) {
96+
trt_results.push_back(trt_results_ivalues.toTensor());
97+
} else {
98+
auto results = trt_results_ivalues.toTuple()->elements();
99+
for (auto r : results) {
100+
trt_results.push_back(r.toTensor());
101+
}
102+
}
86103

87104
for (size_t i = 0; i < trt_results.size(); i++) {
88105
almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]));

0 commit comments

Comments
 (0)