@@ -156,29 +156,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
  return std::move(engine);
}

- // torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
- //   // TODO: Should be doing a functional transform but need PR #31978
- //   // [jit] More robust mangling
- //   // torch::jit::script::Module new_mod = mod.clone();
- //   torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
- //   std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
- //   for (const torch::jit::script::Method& method : mod.get_methods()) {
- //     // Don't convert hidden methods
- //     if (method.name().rfind("_", 0)) {
- //       auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
- //       auto new_g = std::make_shared<torch::jit::Graph>();
- //       AddEngineToGraph(new_mod, new_g, engine);
- //       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
- //       auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
- //       new_mod.type()->addMethod(new_method);
- //       new_method->setSchema(schema);
- //     }
- //   }
- //
- //   return new_mod;
- // }
-
-

void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg,
                              std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new_g) {
@@ -198,7 +175,6 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
    }
  }

- torch::jit::Node *node;
  for (const auto n : seg.nodes()) {
    partitioning::cloneNode(n, g, old_to_new_g);
  }
@@ -212,8 +188,7 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
  return;
}

-
- torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
+ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
  // TODO: Should be doing a functional transform but need PR #31978
  // [jit] More robust mangling
  // torch::jit::script::Module new_mod = mod.clone();
@@ -270,6 +245,33 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
  return new_mod;
}

+
+ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
+   // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporarily
+   if (cfg.convert_info.engine_settings.torch_fallback.enabled) {
+     return CompileGraphWithFallback(mod, cfg);
+   }
+   // TODO: Should be doing a functional transform but need PR #31978
+   // [jit] More robust mangling
+   // torch::jit::script::Module new_mod = mod.clone();
+   torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
+   std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
+   for (const torch::jit::script::Method& method : mod.get_methods()) {
+     // Don't convert hidden methods
+     if (method.name().rfind("_", 0)) {
+       auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
+       auto new_g = std::make_shared<torch::jit::Graph>();
+       AddEngineToGraph(new_mod, new_g, engine);
+       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
+       auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+       new_mod.type()->addMethod(new_method);
+       new_method->setSchema(schema);
+     }
+   }
+
+   return new_mod;
+ }
+

void set_device(const int gpu_id) {
  TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
}
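To make the new entry-point behavior concrete, here is a minimal caller-side sketch (not part of the commit): the wrapper function and the way `cfg` is obtained are hypothetical, while `CompileSpec`, `CompileGraph`, `CompileGraphWithFallback`, and the `torch_fallback` field path are taken from the diff above.

// Hypothetical caller, illustration only -- not in this commit.
// Field names follow the diff; the wrapper itself is assumed.
#include <torch/script.h>

torch::jit::script::Module CompileWithOptionalFallback(
    const torch::jit::script::Module& mod, CompileSpec cfg, bool allow_fallback) {
  // Flag checked by the new CompileGraph() dispatch above.
  cfg.convert_info.engine_settings.torch_fallback.enabled = allow_fallback;
  // When the flag is set, CompileGraph() forwards to CompileGraphWithFallback(),
  // which uses the partitioning/SegmentedBlock path to mix TensorRT engines
  // with Torch-executed segments; otherwise each non-hidden method is converted
  // into a single TensorRT engine.
  return CompileGraph(mod, cfg);
}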