@@ -373,6 +373,17 @@ Status deserialize_onnx_model(int32_t fd, bool is_serialized_as_text, ::ONNX_NAM
373
373
return Status::success ();
374
374
}
375
375
376
+ // Internal helper function used for ONNXRT-TRT EP to filter out DDS nodes
377
+ bool isDDSOp (char const * op_name)
378
+ {
379
+ auto is = [op_name](char const * name) { return std::strcmp (op_name, name) == 0 ; };
380
+ if (is (" NonMaxSuppression" ) || is (" NonZero" ) || is (" RoiAlign" ))
381
+ {
382
+ return true ;
383
+ }
384
+ return false ;
385
+ }
386
+
376
387
bool ModelImporter::supportsModel (void const * serialized_onnx_model, size_t serialized_onnx_model_size,
377
388
SubGraphCollection_t& sub_graph_collection, char const * model_path)
378
389
{
@@ -446,13 +457,13 @@ bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t seri
446
457
{
447
458
::ONNX_NAMESPACE::NodeProto const & node = model.graph ().node (node_idx);
448
459
// Add the node to the subgraph if:
449
- // 1. There is an importer function registered for the operator type
460
+ // 1. It is not a node that requires DDS
450
461
// 2. It is not directly connected to an unsupported input
451
462
// 3. The importer function did not throw an assertion
452
- bool registered = supportsOperator (node.op_type ().c_str ());
463
+ bool unsupportedDDS = isDDSOp (node.op_type ().c_str ());
453
464
bool unsupportedInput = (input_node.empty ()) ? false : checkForInput (node);
454
465
bool unsuccessfulParse = node_idx == error_node;
455
- if (registered && !unsupportedInput && !unsuccessfulParse)
466
+ if (!unsupportedDDS && !unsupportedInput && !unsuccessfulParse)
456
467
{
457
468
if (newSubGraph)
458
469
{
@@ -481,22 +492,8 @@ bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t seri
481
492
return allSupported;
482
493
}
483
494
484
- // This funciton is used by ONNXRT to partition out unsupported nodes
485
495
bool ModelImporter::supportsOperator (char const * op_name) const
486
496
{
487
- auto is = [op_name](char const * name) { return std::strcmp (op_name, name) == 0 ; };
488
-
489
- // Mark these following plugins as supported
490
- if (is (" EfficientNMS_TRT" ) || is (" PyramidROIAlign_TRT" ) || is (" MultilevelCropAndResize_TRT" )
491
- || is (" DisentangledAttention_TRT" ))
492
- {
493
- return true ;
494
- }
495
- // Disable nodes that rely on DDS as ONNXRuntime does not support it at the moment
496
- if (is (" NonMaxSuppression" ) || is (" NonZero" ) || is (" RoiAlign" ))
497
- {
498
- return false ;
499
- }
500
497
return _op_importers.count (op_name);
501
498
}
502
499
0 commit comments