@@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
286
286
}
287
287
288
288
bool QNNExecutionProvider::IsNodeSupported (qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
289
- std::unordered_map<const NodeUnit*, bool >& node_unit_supported_result,
290
289
const logging::Logger& logger) const {
291
- // If we have visited one of the nodes in the node_unit, use the result directly
292
- const auto it = node_unit_supported_result.find (&node_unit);
293
- if (it != node_unit_supported_result.cend ()) {
294
- return it->second ;
290
+ const std::string& op_type = node_unit.OpType ();
291
+ bool supported = false ;
292
+ const auto * op_builder = qnn::GetOpBuilder (op_type);
293
+ if (op_builder == nullptr ) {
294
+ LOGS (logger, WARNING) << " Operators of type `" << node_unit.OpType () << " ` are not supported by QNN EP."
295
+ << node_unit.OpType () << " node `" << node_unit.Name ()
296
+ << " ` will not be assigned to QNN EP." ;
295
297
} else {
296
- const std::string& op_type = node_unit.OpType ();
297
-
298
- bool supported = false ;
299
- const auto * op_builder = qnn::GetOpBuilder (op_type);
300
- if (op_builder == nullptr ) {
301
- LOGS (logger, WARNING) << " Operators of type `" << node_unit.OpType () << " ` are not supported by QNN EP."
302
- << node_unit.OpType () << " node `" << node_unit.Name ()
303
- << " ` will not be assigned to QNN EP." ;
304
- } else {
305
- auto status = op_builder->IsOpSupported (qnn_model_wrapper,
306
- node_unit, logger);
307
- if (Status::OK () != status) {
308
- LOGS (logger, WARNING) << node_unit.OpType () << " node `" << node_unit.Name ()
309
- << " ` is not supported: " << status.ErrorMessage ();
310
- }
311
- supported = (Status::OK () == status);
298
+ auto status = op_builder->IsOpSupported (qnn_model_wrapper,
299
+ node_unit, logger);
300
+ if (Status::OK () != status) {
301
+ LOGS (logger, WARNING) << node_unit.OpType () << " node `" << node_unit.Name ()
302
+ << " ` is not supported: " << status.ErrorMessage ();
312
303
}
313
- node_unit_supported_result[&node_unit] = supported;
314
- return supported;
304
+ supported = (Status::OK () == status);
315
305
}
306
+ return supported;
316
307
}
317
308
318
309
std::unordered_set<const Node*>
@@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
391
382
if (node != &node_unit->GetNode ()) {
392
383
continue ;
393
384
}
394
- const bool supported = IsNodeSupported (qnn_model_wrapper,
395
- *node_unit,
396
- node_unit_supported_result,
397
- logger);
398
- LOGS (logger, VERBOSE) << " Node supported: [" << supported
399
- << " ] index: [" << node->Index ()
400
- << " ] name: [" << node->Name ()
401
- << " ] Operator type: [" << node->OpType ()
402
- << " ] as part of the NodeUnit type: [" << node_unit->OpType ()
403
- << " ] index: [" << node_unit->Index ()
404
- << " ] name: [" << node_unit->Name ()
405
- << " ]" ;
385
+
386
+ if (node_unit_supported_result.count (node_unit) != 0 ) {
387
+ continue ; // Already handled this node unit
388
+ }
389
+
390
+ // Try to convert certain standalone DQ -> Q sequences into QNN Convert op
391
+ auto convert_result = TryHandleConvertSequence (qnn_model_wrapper,
392
+ *node_unit,
393
+ node_unit_map,
394
+ logger,
395
+ true /* do_op_validation*/ );
396
+ if (!convert_result.status .IsOK ()) {
397
+ LOGS (logger, WARNING) << " Failed to convert DQ -> Q sequence to QNN Convert. "
398
+ << " Type: " << node_unit->OpType () << " , Node name: " << node_unit->Name () << " , "
399
+ << " Message: " << convert_result.status .ErrorMessage ();
400
+ }
401
+
402
+ bool supported = false ;
403
+
404
+ if (convert_result.status .IsOK () && convert_result.q_node_unit ) { // Merged DQ -> Q sequence into QNN Convert op
405
+ supported = true ;
406
+
407
+ // Mark the Q node unit as handled and supported here so that we don't try to process it again.
408
+ node_unit_supported_result.insert ({convert_result.q_node_unit , true });
409
+ supported_nodes.insert (&convert_result.q_node_unit ->GetNode ());
410
+ } else {
411
+ supported = IsNodeSupported (qnn_model_wrapper, *node_unit, logger);
412
+ LOGS (logger, VERBOSE) << " Node supported: [" << supported
413
+ << " ] index: [" << node->Index ()
414
+ << " ] name: [" << node->Name ()
415
+ << " ] Operator type: [" << node->OpType ()
416
+ << " ] as part of the NodeUnit type: [" << node_unit->OpType ()
417
+ << " ] index: [" << node_unit->Index ()
418
+ << " ] name: [" << node_unit->Name ()
419
+ << " ]" ;
420
+ }
421
+
406
422
if (supported) {
407
423
// If the node_unit is supported, add all of its nodes to the supported list.
408
424
for (const auto * node_in_group : node_unit->GetAllNodesInGroup ()) {
409
425
supported_nodes.insert (node_in_group);
410
426
}
411
427
}
428
+
429
+ node_unit_supported_result.insert ({node_unit, supported});
412
430
}
413
431
414
432
return supported_nodes;
0 commit comments