@@ -138,7 +138,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::Block* block,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
@@ -149,17 +150,20 @@ partitioning::GraphAndMapping BuildHybridGraph(
   // TODO: Combine this within partition call
   partitioning::populateInputIValues(&partitioning_ctx);
 
-  partitioning::partition(&partitioning_ctx);
+  partitioning::partition(&partitioning_ctx, expect_full_compilation);
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+    int num_torch_segments = 0;
+    int num_trt_segments = 0;
 
     for (auto& seg_block : segmented_blocks) {
       LOG_INFO("Block segment:" << seg_block);
       std::ostringstream trt_engine_id;
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        num_trt_segments++;
        auto inputs = seg_block.construct_inputs_spec();
        // update the input ranges for each segment
        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
@@ -180,8 +184,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
             true);
 
         seg_block.update_graph(temp_g);
+      } else {
+        num_torch_segments++;
+
+        // If full compilation is expected, ensure that all operators in Torch blocks are
+        // for collections processing
+        if (expect_full_compilation) {
+          for (auto torch_node : seg_block.block()->nodes()) {
+            if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
+              TORCHTRT_THROW_ERROR(
+                  "Full compilation specified but node "
+                  << *torch_node
+                  << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
+                  << " Try recompiling with require_full_compilation=False.");
+            }
+          }
+        }
       }
     }
+
+    // If full compilation is expected, cannot have more than 2 Torch segments
+    // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
+    if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
+      TORCHTRT_THROW_ERROR(
+          "Full compilation was requested but unable to convert all operations to TensorRT."
+          << " Try recompiling with require_full_compilation=False.");
+    }
   }
 
   return partitioning::stitch(&partitioning_ctx, block);
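
To make the new invariant concrete, here is a minimal standalone sketch of the segment-count check added above. It is illustrative only; `satisfies_full_compilation` is a hypothetical name, not part of the Torch-TensorRT codebase.

// With collection I/O, a "fully compiled" partition may still contain up to
// two Torch segments: one that unpacks collection inputs ahead of the engine,
// and one that repacks collection outputs after it.
#include <cassert>

bool satisfies_full_compilation(int num_torch_segments, int num_trt_segments) {
  return num_torch_segments <= 2 && num_trt_segments == 1;
}

int main() {
  assert(satisfies_full_compilation(2, 1));  // unpack + engine + repack
  assert(satisfies_full_compilation(0, 1));  // single TRT engine, no glue code
  assert(!satisfies_full_compilation(1, 2)); // two engines imply a true fallback split
  return 0;
}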
@@ -191,7 +219,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::CollectionTypeMap& first_use_type_map) {
+    ir::CollectionTypeMap& first_use_type_map,
+    bool requires_collection_handling = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -226,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
             "Cannot infer input type from calculations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
         spec[i].dtype = at::kFloat;
-      } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
+      } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
        if (!est_type_opt[i]) {
          LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
          std::stringstream ss;
@@ -297,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   return engine;
 }
 
+bool userRequestedFallback(CompileSpec& cfg) {
+  return cfg.lower_info.forced_fallback_modules.size() != 0 ||
+      cfg.partitioning_info.forced_fallback_operators.size() != 0;
+}
+
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
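The new helper names the pair of checks that were previously inlined in CompileGraph (removed in the final hunk below). A minimal model of the predicate, using a hypothetical `FallbackSpec` stand-in since constructing a real CompileSpec is out of scope here:

// Illustrative only: FallbackSpec mirrors just the two fields the helper reads.
#include <cassert>
#include <string>
#include <vector>

struct FallbackSpec {
  std::vector<std::string> forced_fallback_modules;   // mirrors cfg.lower_info
  std::vector<std::string> forced_fallback_operators; // mirrors cfg.partitioning_info
};

bool user_requested_fallback(const FallbackSpec& cfg) {
  return !cfg.forced_fallback_modules.empty() || !cfg.forced_fallback_operators.empty();
}

int main() {
  FallbackSpec none;
  assert(!user_requested_fallback(none));

  FallbackSpec some;
  some.forced_fallback_operators.push_back("aten::upsample_nearest2d");
  assert(user_requested_fallback(some)); // any forced module or operator counts
  return 0;
}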
@@ -315,8 +349,17 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
+  // Determine if the block is convertible/has collection output, and based on the result,
+  // whether full compilation can be expected
+  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
+  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+
+  // Determine whether user specifications necessitate partitioning
+  auto isFallbackRequested = userRequestedFallback(cfg);
+
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
@@ -330,20 +373,28 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     user_requested_long &= (casts_inserted > 0);
   }
 
-  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
-  auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled && !user_requested_long &&
-      (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-       cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
-      !outputIsCollection) {
+  // Partitioning is required if:
+  // 1. The user requested some modules/operators fall back to Torch
+  // 2. The block (graph) cannot be converted due to operator coverage
+  // 3. The output of the graph is a collection
+  // 4. The user requested a non-TRT data type input
+  auto isPartitioningRequired =
+      (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
+
+  // The user did not require full compilation, but the model can be fully compiled
+  if (cfg.partitioning_info.enabled && !isPartitioningRequired) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
-  if (cfg.partitioning_info.enabled &&
-      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection || user_requested_long)) {
-    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+  // The user did not require full compilation, and the model cannot be fully compiled,
+  // or the user required full compilation but the I/O of the graph uses collections
+  if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
+    // If the model is fully compilable and the user has specified full compilation, run partitioning
+    // to generate collection-processing code in Torch
+    auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
+
+    auto graph_and_mapping =
+        BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
     // rename the graph inputs after fallback to ensure PyTorch deserializes them correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
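
The net effect of this control-flow change, reduced to plain booleans, is sketched below. This is illustrative only; `choose_path` and `Decision` are hypothetical names, not part of the Torch-TensorRT API.

#include <cassert>

struct Decision {
  bool run_partitioning;
  bool expect_full_compilation;
};

Decision choose_path(
    bool partitioning_enabled, // cfg.partitioning_info.enabled, i.e. require_full_compilation=False
    bool fallback_requested,   // userRequestedFallback(cfg)
    bool block_convertible,    // conversion::VerifyConverterSupportForBlock(...)
    bool output_is_collection, // conversion::OutputIsCollection(...)
    bool user_requested_long) {
  bool requires_collection_handling = block_convertible && output_is_collection;
  bool partitioning_required =
      fallback_requested || !block_convertible || output_is_collection || user_requested_long;

  Decision d;
  d.run_partitioning = (partitioning_enabled && partitioning_required) || requires_collection_handling;
  // Full compilation is only "expected" when partitioning runs solely to emit
  // collection-processing code in Torch.
  d.expect_full_compilation = requires_collection_handling && !partitioning_enabled;
  return d;
}

int main() {
  // require_full_compilation=True on a convertible graph with tuple/list outputs:
  // partitioning still runs, but only collection glue may remain in Torch.
  Decision d = choose_path(false, false, true, true, false);
  assert(d.run_partitioning && d.expect_full_compilation);
  return 0;
}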