Skip to content

Commit 0462dc3

Browse files
committed
Mark user-supplised plugins as supported in ONNXRT-TRT
Signed-off-by: Kevin Chen <[email protected]>
1 parent 6ba67d3 commit 0462dc3

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

ModelImporter.cpp

+14-17
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,17 @@ Status deserialize_onnx_model(int32_t fd, bool is_serialized_as_text, ::ONNX_NAM
373373
return Status::success();
374374
}
375375

376+
// Internal helper function used for ONNXRT-TRT EP to filter out DDS nodes
377+
bool isDDSOp(char const* op_name)
378+
{
379+
auto is = [op_name](char const* name) { return std::strcmp(op_name, name) == 0; };
380+
if (is("NonMaxSuppression") || is("NonZero") || is("RoiAlign"))
381+
{
382+
return true;
383+
}
384+
return false;
385+
}
386+
376387
bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
377388
SubGraphCollection_t& sub_graph_collection, char const* model_path)
378389
{
@@ -446,13 +457,13 @@ bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t seri
446457
{
447458
::ONNX_NAMESPACE::NodeProto const& node = model.graph().node(node_idx);
448459
// Add the node to the subgraph if:
449-
// 1. There is an importer function registered for the operator type
460+
// 1. It is not a node that requires DDS
450461
// 2. It is not directly connected to an unsupported input
451462
// 3. The importer function did not throw an assertion
452-
bool registered = supportsOperator(node.op_type().c_str());
463+
bool unsupportedDDS = isDDSOp(node.op_type().c_str());
453464
bool unsupportedInput = (input_node.empty()) ? false : checkForInput(node);
454465
bool unsuccessfulParse = node_idx == error_node;
455-
if (registered && !unsupportedInput && !unsuccessfulParse)
466+
if (!unsupportedDDS && !unsupportedInput && !unsuccessfulParse)
456467
{
457468
if (newSubGraph)
458469
{
@@ -481,22 +492,8 @@ bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t seri
481492
return allSupported;
482493
}
483494

484-
// This funciton is used by ONNXRT to partition out unsupported nodes
485495
bool ModelImporter::supportsOperator(char const* op_name) const
486496
{
487-
auto is = [op_name](char const* name) { return std::strcmp(op_name, name) == 0; };
488-
489-
// Mark these following plugins as supported
490-
if (is("EfficientNMS_TRT") || is("PyramidROIAlign_TRT") || is("MultilevelCropAndResize_TRT")
491-
|| is("DisentangledAttention_TRT"))
492-
{
493-
return true;
494-
}
495-
// Disable nodes that rely on DDS as ONNXRuntime does not support it at the moment
496-
if (is("NonMaxSuppression") || is("NonZero") || is("RoiAlign"))
497-
{
498-
return false;
499-
}
500497
return _op_importers.count(op_name);
501498
}
502499

0 commit comments

Comments
 (0)