@@ -55,14 +55,17 @@ int main(int argc, const char* argv[]) {
55
55
dims.push_back (v);
56
56
}
57
57
58
+ auto extra_info = trtorch::ExtraInfo (dims);
59
+ extra_info.workspace_size = 1 << 24 ;
60
+
58
61
std::cout << " Checking operator support" << std::endl;
59
62
if (!trtorch::CheckMethodOperatorSupport (mod, " forward" )) {
60
63
std::cerr << " Method is not currently supported by TRTorch" << std::endl;
61
64
return -1 ;
62
65
}
63
66
64
67
std::cout << " Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
65
- auto engine = trtorch::ConvertGraphToTRTEngine (mod, " forward" , dims );
68
+ auto engine = trtorch::ConvertGraphToTRTEngine (mod, " forward" , extra_info );
66
69
std::ofstream out (" /tmp/engine_converted_from_jit.trt" );
67
70
out << engine;
68
71
out.close ();
@@ -75,14 +78,28 @@ int main(int argc, const char* argv[]) {
75
78
76
79
torch::jit::IValue jit_results_ivalues = mod.forward (jit_inputs_ivalues);
77
80
std::vector<at::Tensor> jit_results;
78
- jit_results.push_back (jit_results_ivalues.toTensor ());
81
+ if (jit_results_ivalues.isTensor ()) {
82
+ jit_results.push_back (jit_results_ivalues.toTensor ());
83
+ } else {
84
+ auto results = jit_results_ivalues.toTuple ()->elements ();
85
+ for (auto r : results) {
86
+ jit_results.push_back (r.toTensor ());
87
+ }
88
+ }
79
89
80
90
std::cout << " Compiling graph as module" << std::endl;
81
- auto trt_mod = trtorch::CompileGraph (mod, dims );
91
+ auto trt_mod = trtorch::CompileGraph (mod, extra_info );
82
92
std::cout << " Running TRT module" << std::endl;
83
93
torch::jit::IValue trt_results_ivalues = trt_mod.forward (trt_inputs_ivalues);
84
94
std::vector<at::Tensor> trt_results;
85
- trt_results.push_back (trt_results_ivalues.toTensor ());
95
+ if (trt_results_ivalues.isTensor ()) {
96
+ trt_results.push_back (trt_results_ivalues.toTensor ());
97
+ } else {
98
+ auto results = trt_results_ivalues.toTuple ()->elements ();
99
+ for (auto r : results) {
100
+ trt_results.push_back (r.toTensor ());
101
+ }
102
+ }
86
103
87
104
for (size_t i = 0 ; i < trt_results.size (); i++) {
88
105
almostEqual (jit_results[i], trt_results[i].reshape_as (jit_results[i]));
0 commit comments