@@ -68,8 +68,8 @@ TRTEngine::TRTEngine(
68
68
uint64_t inputs = 0 ;
69
69
uint64_t outputs = 0 ;
70
70
71
- for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x ++) {
72
- std::string bind_name = cuda_engine->getBindingName (x );
71
+ for (int64_t trt_idx = 0 ; trt_idx < cuda_engine->getNbIOTensors (); trt_idx ++) {
72
+ std::string bind_name = cuda_engine->getIOTensorName (trt_idx );
73
73
LOG_DEBUG (" Binding name: " << bind_name);
74
74
auto delim = bind_name.find (" ." );
75
75
if (delim == std::string::npos) {
@@ -80,46 +80,45 @@ TRTEngine::TRTEngine(
80
80
<< bind_name
81
81
<< " \n Ensure module was compiled with Torch-TensorRT.ts or follows Torch-TensorRT Runtime conventions" );
82
82
}
83
-
84
83
std::string idx_s = bind_name.substr (delim + 1 );
85
- uint64_t idx = static_cast <uint64_t >(std::stoi (idx_s));
84
+ uint64_t pyt_idx = static_cast <uint64_t >(std::stoi (idx_s));
86
85
87
- if (cuda_engine->bindingIsInput (x) ) {
86
+ if (cuda_engine->getTensorIOMode (bind_name. c_str ()) == nvinfer1::TensorIOMode:: kINPUT ) {
88
87
inputs++;
89
- in_binding_map[x ] = idx ;
90
- LOG_DEBUG (" TRT Binding: " << x << " : PYT Input: " << idx );
88
+ in_binding_map[trt_idx ] = pyt_idx ;
89
+ LOG_DEBUG (" TRT Binding index : " << trt_idx << " corresponds to PYT Input index : " << pyt_idx );
91
90
} else {
92
91
outputs++;
93
- out_binding_map[x ] = idx ;
94
- LOG_DEBUG (" TRT Binding: " << x << " : PYT Output: " << idx );
92
+ out_binding_map[trt_idx ] = pyt_idx ;
93
+ LOG_DEBUG (" TRT Binding index : " << trt_idx << " corresponds to PYT Output: " << pyt_idx );
95
94
}
96
95
}
97
96
98
97
num_io = std::make_pair (inputs, outputs);
99
98
in_binding_names.resize (inputs);
100
99
out_binding_names.resize (outputs);
101
-
102
- for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x++) {
103
- std::string bind_name = cuda_engine->getBindingName (x);
104
- if (cuda_engine->bindingIsInput (x)) {
100
+ for (int64_t x = 0 ; x < cuda_engine->getNbIOTensors (); x++) {
101
+ std::string bind_name = cuda_engine->getIOTensorName (x);
102
+ if (cuda_engine->getTensorIOMode (bind_name.c_str ()) == nvinfer1::TensorIOMode::kINPUT ) {
105
103
in_binding_names[in_binding_map.at (x)] = bind_name;
106
104
} else {
107
105
out_binding_names[out_binding_map.at (x)] = bind_name;
108
106
}
109
107
}
110
108
} else {
111
- uint64_t inputs = _in_binding_names.size ();
112
- in_binding_names.resize (inputs );
113
- for (size_t pyt_idx = 0 ; pyt_idx < inputs ; pyt_idx++) {
109
+ uint64_t inputs_size = _in_binding_names.size ();
110
+ in_binding_names.resize (inputs_size );
111
+ for (size_t pyt_idx = 0 ; pyt_idx < inputs_size ; pyt_idx++) {
114
112
auto binding_name = _in_binding_names[pyt_idx];
115
113
auto trt_idx = cuda_engine->getBindingIndex (binding_name.c_str ());
116
- TORCHTRT_CHECK ((trt_idx >= 0 ), " Could not find a TensorRT engine binding for input named " << binding_name );
114
+ std::string engine_binded_name = cuda_engine-> getIOTensorName (pyt_idx );
117
115
TORCHTRT_CHECK (
118
- cuda_engine->bindingIsInput (trt_idx),
116
+ (binding_name == engine_binded_name),
117
+ " Could not find a TensorRT engine binding for input named " << binding_name);
118
+ TORCHTRT_CHECK (
119
+ (cuda_engine->getTensorIOMode (binding_name.c_str ()) == nvinfer1::TensorIOMode::kINPUT ),
119
120
" Binding " << binding_name << " specified as input but found as output in TensorRT engine" );
120
- LOG_DEBUG (
121
- " Input binding name: " << binding_name << " (trt binding idx: " << trt_idx << " , "
122
- << " pyt arg idx: " << pyt_idx << " )" );
121
+ LOG_DEBUG (" Input binding name: " << binding_name << " pyt arg idx: " << pyt_idx << " )" );
123
122
in_binding_map[trt_idx] = pyt_idx;
124
123
in_binding_names[pyt_idx] = _in_binding_names[pyt_idx];
125
124
}
@@ -129,17 +128,18 @@ TRTEngine::TRTEngine(
129
128
for (size_t pyt_idx = 0 ; pyt_idx < outputs; pyt_idx++) {
130
129
auto binding_name = _out_binding_names[pyt_idx];
131
130
auto trt_idx = cuda_engine->getBindingIndex (binding_name.c_str ());
132
- TORCHTRT_CHECK ((trt_idx >= 0 ), " Could not find a TensorRT engine binding for output named " << binding_name);
131
+ std::string engine_binded_name = cuda_engine->getIOTensorName (inputs_size + pyt_idx);
132
+ TORCHTRT_CHECK (
133
+ (binding_name == engine_binded_name),
134
+ " Could not find a TensorRT engine binding for output named " << binding_name);
133
135
TORCHTRT_CHECK (
134
- !cuda_engine->bindingIsInput (trt_idx ),
136
+ !( cuda_engine->getTensorIOMode (binding_name. c_str ()) == nvinfer1::TensorIOMode:: kINPUT ),
135
137
" Binding " << binding_name << " specified as output but found as input in TensorRT engine" );
136
- LOG_DEBUG (
137
- " Output binding name: " << binding_name << " (trt binding idx: " << trt_idx << " , "
138
- << " pyt return idx: " << pyt_idx << " )" );
138
+ LOG_DEBUG (" Output binding name: " << binding_name << " pyt return idx: " << inputs_size + pyt_idx << " )" );
139
139
out_binding_map[trt_idx] = pyt_idx;
140
140
out_binding_names[pyt_idx] = binding_name;
141
141
}
142
- num_io = std::make_pair (inputs , outputs);
142
+ num_io = std::make_pair (inputs_size , outputs);
143
143
}
144
144
145
145
#ifndef NDEBUG
@@ -149,10 +149,10 @@ TRTEngine::TRTEngine(
149
149
}
150
150
151
151
TRTEngine::~TRTEngine () {
152
+ rt.reset ();
152
153
trt_engine_profiler.reset ();
153
154
exec_ctx.reset ();
154
155
cuda_engine.reset ();
155
- rt.reset ();
156
156
}
157
157
158
158
void TRTEngine::disable_profiling () {
@@ -164,7 +164,7 @@ void TRTEngine::disable_profiling() {
164
164
}
165
165
166
166
void TRTEngine::dump_engine_layer_info_to_file (const std::string& path) {
167
- auto inspector = cuda_engine->createEngineInspector ();
167
+ auto inspector = make_trt ( cuda_engine->createEngineInspector () );
168
168
std::ofstream f (path);
169
169
f << std::string (inspector->getEngineInformation (nvinfer1::LayerInformationFormat::kJSON ));
170
170
f.close ();
@@ -208,23 +208,23 @@ std::string TRTEngine::to_str() const {
208
208
std::stringstream ss;
209
209
ss << " Torch-TensorRT TensorRT Engine:" << std::endl;
210
210
ss << " Name: " << name << std::endl;
211
- ss << " Bindings: { " << std::endl;
212
- for (int64_t x = 0 ; x < cuda_engine-> getNbBindings (); x ++) {
213
- if (cuda_engine-> bindingIsInput (x)) {
214
- const uint64_t pyt_idx = in_binding_map. at (x) ;
215
- ss << " ( " << x << " : " << in_binding_names. at (pyt_idx) << " ) Input: [ " << std::endl;
216
- ss << " pytorch arg idx: " << pyt_idx << std::endl;
217
- ss << " shape: " << exec_ctx-> getBindingDimensions (x) << std::endl;
218
- ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx-> getEngine (). getBindingDataType (x)) << std::endl;
219
- ss << " ]" << std::endl;
220
- } else {
221
- const uint64_t pyt_idx = out_binding_map. at (x);
222
- ss << " ( " << x << " : " << out_binding_names. at (pyt_idx) << " ) Output: [ " << std::endl;
223
- ss << " pytorch return idx : " << pyt_idx << std::endl;
224
- ss << " shape : " << exec_ctx-> getBindingDimensions (x) << std::endl;
225
- ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx-> getEngine (). getBindingDataType (x)) << std::endl;
226
- ss << " ] " << std::endl;
227
- }
211
+ ss << " Inputs: [ " << std::endl;
212
+ for (uint64_t i = 0 ; i < num_io. first ; i ++) {
213
+ ss << " id: " << i << std::endl;
214
+ ss << " shape: " << exec_ctx-> getTensorShape ( std::string ( " input_ " + str (i)). c_str ()) << std::endl ;
215
+ ss << " dtype: "
216
+ << util::TRTDataTypeToScalarType (exec_ctx-> getEngine (). getTensorDataType ( std::string ( " input_ " + str (i)). c_str ()))
217
+ << std::endl;
218
+ }
219
+ ss << " ]" << std::endl;
220
+ ss << " Outputs: [ " << std::endl;
221
+ for ( uint64_t o = 0 ; o < num_io. second ; o++) {
222
+ ss << " id : " << o << std::endl;
223
+ ss << " shape : " << exec_ctx-> getTensorShape ( std::string ( " output_ " + str (o)). c_str ()) << std::endl;
224
+ ss << " dtype : "
225
+ << util::TRTDataTypeToScalarType (
226
+ exec_ctx-> getEngine (). getTensorDataType ( std::string ( " output_ " + str (o)). c_str ()))
227
+ << std::endl;
228
228
}
229
229
ss << " }" << std::endl;
230
230
ss << " Device: " << device_info << std::endl;
0 commit comments