Skip to content

Commit b847121

Browse files
QNN EP: Fuse DQ -> Q sequences into a QNN Convert op (#19511)
### Description Fuses DQ -> Q sequences into a QNN Convert operator if: - Converting from one qtype to another. Ex: Dequantize(uint8 to float) -> Quantize(float to uint16) - The DQ and Q operators are not part of another node unit (i.e., standalone) - The Q operator is the only consumer for the DQ operator. ### Motivation and Context Allows faster execution of QDQ models with mixed activation types by leveraging the QNN Convert operator, which converts between quantization types. For certain models, this results in inference latency speed-ups of up to 2x (depends on the number of DQ -> Q sequences). #### Example for Add node unit with 16-bit I/O: Original: ``` u8 ----> DQ ---> Q ---u16--> Add ---u16--> ^ | u16 --------------------------+ ``` After fusing DQ -> Q: ``` u8 ----> Convert ---u16--> Add ---u16--> ^ | u16 ------------------------+ ```
1 parent ef0b713 commit b847121

File tree

8 files changed

+319
-41
lines changed

8 files changed

+319
-41
lines changed

Diff for: onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc

+43
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,49 @@ bool IsQDQPairSupported(
7676
}
7777
}
7878

79+
bool IsDQQConversion(
80+
const Node& dq_node, const Node& q_node,
81+
const GetConstantInitializerFn& get_const_initializer,
82+
const Path& model_path) {
83+
ConstPointerContainer<std::vector<NodeArg*>> dq_input_defs = dq_node.InputDefs();
84+
ConstPointerContainer<std::vector<NodeArg*>> q_input_defs = q_node.InputDefs();
85+
86+
// Q/DQ contains optional input is not supported
87+
// non-scalar Q/DQ scale and zero point needs are not supported
88+
if (dq_input_defs.size() != InputIndex::TOTAL_COUNT ||
89+
q_input_defs.size() != InputIndex::TOTAL_COUNT ||
90+
!optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) ||
91+
!optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) ||
92+
!optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) ||
93+
!optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) {
94+
return false;
95+
}
96+
97+
// if Q/DQ scale and zero point are not constant, return false
98+
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
99+
get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name());
100+
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
101+
get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name());
102+
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
103+
get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
104+
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
105+
get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
106+
if (nullptr == q_zp_tensor_proto ||
107+
nullptr == dq_zp_tensor_proto ||
108+
nullptr == q_scale_tensor_proto ||
109+
nullptr == dq_scale_tensor_proto) {
110+
return false;
111+
}
112+
113+
// check Q/DQ have same scale type and different zero point type
114+
Initializer q_zp(*q_zp_tensor_proto, model_path);
115+
Initializer q_scale(*q_scale_tensor_proto, model_path);
116+
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
117+
Initializer dq_scale(*dq_scale_tensor_proto, model_path);
118+
119+
return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type());
120+
}
121+
79122
bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
80123
bool zero_point_exists = false;
81124
if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) {

Diff for: onnxruntime/core/optimizer/qdq_transformer/qdq_util.h

+12
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ bool IsQDQPairSupported(
3838
const GetConstantInitializerFn& get_const_initializer,
3939
const Path& model_path);
4040

41+
// Check if a DQ -> Q sequence represents a conversion in quantization data type.
// Example of uint8 to uint16:
//     Dequantize (uint8 to float) -> Quantize (float to uint16)
// Requires:
// 1. Q/DQ doesn't have optional input.
// 2. scale and zero-point are constant scalars.
// 3. Q and DQ have the same scale *type* and different zero-point *types*.
bool IsDQQConversion(
    const Node& dq_node, const Node& q_node,
    const GetConstantInitializerFn& get_const_initializer,
    const Path& model_path);
52+
4153
// Check if DQ is supported in extended level QDQ transformers. It requires:
4254
// 1. DQ doesn't have optional input.
4355
// 2. scale and zero point are constant scalars

Diff for: onnxruntime/core/providers/qnn/builder/op_builder_factory.h

+23
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
9494

9595
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
9696

97+
// Result of attempting to fuse a DQ -> Q sequence into a QNN Convert op.
struct HandleConvertResult {
  Status status;                // Indicates an unexpected error. Check if q_node_unit != nullptr to determine
                                // whether a DQ -> Q sequence was successfully merged into a Convert.
  const NodeUnit* q_node_unit;  // Non-null if successfully merged DQ -> Q sequence.
                                // Set to nullptr if this node unit could not be merged into a Convert.
};

/**
 * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from
 * one quantization type (e.g., uint8_t) to another (e.g., uint16_t).
 *
 * \param qnn_model_wrapper The QNN model that is being built.
 * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence.
 * \param node_unit_map Maps each graph node to the NodeUnit that contains it.
 * \param logger The logger.
 * \param do_op_validation True if should call QNN operator validation APIs.
 * \return A qnn::HandleConvertResult object that indicates success/failure and provides a pointer
 *         to the Q node unit that was successfully merged with the provided DQ node unit.
 */
HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
                                             const NodeUnit& maybe_dq_node_unit,
                                             const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                             const logging::Logger& logger,
                                             bool do_op_validation);
97120
} // namespace qnn
98121
} // namespace onnxruntime
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/graph/graph_utils.h"
5+
#include "core/optimizer/qdq_transformer/qdq_util.h"
6+
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
7+
#include "core/providers/shared/utils/utils.h"
8+
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
9+
#include "core/providers/qnn/builder/op_builder_factory.h"
10+
#include "core/common/safeint.h"
11+
#include "onnx/defs/data_type_utils.h"
12+
13+
#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values).
14+
15+
namespace onnxruntime {
16+
namespace qnn {
17+
18+
// Op builder that emits a single QNN Convert operator in place of a fused
// DQ -> Q node sequence.
class ConvertOpBuilder : public BaseOpBuilder {
 public:
  ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder);

  // Adds a QNN Convert node to the model: the DQ node unit supplies the input
  // (with its quantization parameters) and the Q node unit supplies the output.
  Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
                                  const NodeUnit& dq_node_unit,
                                  const NodeUnit& q_node_unit,
                                  const logging::Logger& logger,
                                  bool do_op_validation) const ORT_MUST_USE_RESULT;
};

Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& dq_node_unit,
                                                  const NodeUnit& q_node_unit,
                                                  const logging::Logger& logger,
                                                  bool do_op_validation) const {
  std::vector<std::string> input_names;

  // Process the input from the DQ node
  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names));

  // Process the output from the Q node. Override the QNN operator type to "Convert".
  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {},
                                     logger, do_op_validation, QNN_OP_CONVERT));
  return Status::OK();
}
45+
46+
/**
 * Attempts to fuse a standalone DQ -> Q pair (a quantization-type conversion)
 * into a single QNN Convert operator. On success, the returned result's
 * q_node_unit points at the merged Q node unit; otherwise it is nullptr.
 */
HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
                                             const NodeUnit& maybe_dq_node_unit,
                                             const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                             const logging::Logger& logger,
                                             bool do_op_validation) {
  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();

  // The sequence must begin with a DQ that is not already absorbed into a QDQ node unit.
  const bool starts_with_standalone_dq = maybe_dq_node_unit.OpType() == QDQ::DQOpName &&
                                         maybe_dq_node_unit.UnitType() == NodeUnit::Type::SingleNode;
  if (!starts_with_standalone_dq) {
    return {};
  }

  const Node& dq_node = maybe_dq_node_unit.GetNode();

  // The DQ must feed exactly one Q consumer and nothing else, and must not
  // also produce a graph output.
  auto q_children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName);
  const bool q_is_sole_consumer = q_children.size() == 1 &&
                                  dq_node.GetOutputEdgesCount() == 1 &&
                                  !graph_viewer.NodeProducesGraphOutput(dq_node);
  if (!q_is_sole_consumer) {
    return {};
  }

  const Node& q_node = *q_children[0];
  const auto q_unit_it = node_unit_map.find(&q_node);
  if (q_unit_it == node_unit_map.end()) {
    // Every node should have been mapped to a NodeUnit; report the inconsistency.
    return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr};
  }

  const NodeUnit* q_node_unit = q_unit_it->second;

  // The Q must likewise be standalone (not part of another QDQ node unit).
  if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
    return {};
  }

  auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
    return graph_viewer.GetConstantInitializer(initializer_name, true);
  };

  // Only fuse when the pair is a genuine type conversion: same scale type,
  // different zero-point types.
  if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) {
    return {};
  }

  LOGS(logger, VERBOSE) << " Adding QNN Convert. dq_node name: [" << dq_node.Name()
                        << "] dq_node optype: [" << dq_node.OpType()
                        << "] q_node name: [" << q_node_unit->Name()
                        << "] q_node optype: [" << q_node_unit->OpType()
                        << "]";

  ConvertOpBuilder op_builder;
  auto status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit, logger,
                                                    do_op_validation);
  return status.IsOK() ? HandleConvertResult{status, q_node_unit} : HandleConvertResult{status, nullptr};
}
101+
102+
} // namespace qnn
103+
} // namespace onnxruntime

Diff for: onnxruntime/core/providers/qnn/builder/qnn_model.cc

+30-5
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
114114
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper.");
115115
}
116116

117+
std::unordered_set<const NodeUnit*> handled_node_units;
118+
117119
// Op builder
118120
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
119121
for (size_t i = 0; i < node_indices.size(); i++) {
@@ -122,20 +124,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
122124
// Check whether it's part of NodeUnit
123125
const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map);
124126
// Q, DQ nodes in the node unit only carry the quantization parameters
125-
// Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node)
127+
// Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node)
126128
const std::string& op_type = node_unit.OpType();
129+
130+
if (node != &node_unit.GetNode()) {
131+
continue;
132+
}
133+
134+
if (handled_node_units.count(&node_unit) != 0) {
135+
continue; // Already handled.
136+
}
137+
138+
// Try to convert particular DQ -> Q sequences into QNN Convert op
139+
auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
140+
node_unit,
141+
node_unit_map,
142+
logger_,
143+
false /*do_op_validation*/);
144+
ORT_RETURN_IF_ERROR(convert_result.status);
145+
146+
if (convert_result.q_node_unit) {
147+
// Successfully merged DQ -> Q sequence into a QNN Convert op.
148+
// Mark both of these node units as handled.
149+
handled_node_units.insert(&node_unit);
150+
handled_node_units.insert(convert_result.q_node_unit);
151+
continue;
152+
}
153+
127154
LOGS(logger_, VERBOSE) << " node name: [" << node->Name()
128155
<< "] node optype: [" << op_type
129156
<< "] as part of the NodeUnit type: [" << node_unit.OpType()
130157
<< "] name: [" << node_unit.Name()
131158
<< "]";
132-
if (node != &node_unit.GetNode()) {
133-
continue;
134-
}
135-
136159
if (const auto* op_builder = GetOpBuilder(op_type)) {
137160
ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_));
138161
}
162+
163+
handled_node_units.insert(&node_unit);
139164
}
140165

141166
ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph.");

Diff for: onnxruntime/core/providers/qnn/qnn_execution_provider.cc

+53-35
Original file line numberDiff line numberDiff line change
@@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
286286
}
287287

288288
bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
289-
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
290289
const logging::Logger& logger) const {
291-
// If we have visited one of the nodes in the node_unit, use the result directly
292-
const auto it = node_unit_supported_result.find(&node_unit);
293-
if (it != node_unit_supported_result.cend()) {
294-
return it->second;
290+
const std::string& op_type = node_unit.OpType();
291+
bool supported = false;
292+
const auto* op_builder = qnn::GetOpBuilder(op_type);
293+
if (op_builder == nullptr) {
294+
LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP."
295+
<< node_unit.OpType() << " node `" << node_unit.Name()
296+
<< "` will not be assigned to QNN EP.";
295297
} else {
296-
const std::string& op_type = node_unit.OpType();
297-
298-
bool supported = false;
299-
const auto* op_builder = qnn::GetOpBuilder(op_type);
300-
if (op_builder == nullptr) {
301-
LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP."
302-
<< node_unit.OpType() << " node `" << node_unit.Name()
303-
<< "` will not be assigned to QNN EP.";
304-
} else {
305-
auto status = op_builder->IsOpSupported(qnn_model_wrapper,
306-
node_unit, logger);
307-
if (Status::OK() != status) {
308-
LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
309-
<< "` is not supported: " << status.ErrorMessage();
310-
}
311-
supported = (Status::OK() == status);
298+
auto status = op_builder->IsOpSupported(qnn_model_wrapper,
299+
node_unit, logger);
300+
if (Status::OK() != status) {
301+
LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
302+
<< "` is not supported: " << status.ErrorMessage();
312303
}
313-
node_unit_supported_result[&node_unit] = supported;
314-
return supported;
304+
supported = (Status::OK() == status);
315305
}
306+
return supported;
316307
}
317308

318309
std::unordered_set<const Node*>
@@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
391382
if (node != &node_unit->GetNode()) {
392383
continue;
393384
}
394-
const bool supported = IsNodeSupported(qnn_model_wrapper,
395-
*node_unit,
396-
node_unit_supported_result,
397-
logger);
398-
LOGS(logger, VERBOSE) << "Node supported: [" << supported
399-
<< "] index: [" << node->Index()
400-
<< "] name: [" << node->Name()
401-
<< "] Operator type: [" << node->OpType()
402-
<< "] as part of the NodeUnit type: [" << node_unit->OpType()
403-
<< "] index: [" << node_unit->Index()
404-
<< "] name: [" << node_unit->Name()
405-
<< "]";
385+
386+
if (node_unit_supported_result.count(node_unit) != 0) {
387+
continue; // Already handled this node unit
388+
}
389+
390+
// Try to convert certain standalone DQ -> Q sequences into QNN Convert op
391+
auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
392+
*node_unit,
393+
node_unit_map,
394+
logger,
395+
true /*do_op_validation*/);
396+
if (!convert_result.status.IsOK()) {
397+
LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. "
398+
<< "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", "
399+
<< "Message: " << convert_result.status.ErrorMessage();
400+
}
401+
402+
bool supported = false;
403+
404+
if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op
405+
supported = true;
406+
407+
// Mark the Q node unit as handled and supported here so that we don't try to process it again.
408+
node_unit_supported_result.insert({convert_result.q_node_unit, true});
409+
supported_nodes.insert(&convert_result.q_node_unit->GetNode());
410+
} else {
411+
supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger);
412+
LOGS(logger, VERBOSE) << "Node supported: [" << supported
413+
<< "] index: [" << node->Index()
414+
<< "] name: [" << node->Name()
415+
<< "] Operator type: [" << node->OpType()
416+
<< "] as part of the NodeUnit type: [" << node_unit->OpType()
417+
<< "] index: [" << node_unit->Index()
418+
<< "] name: [" << node_unit->Name()
419+
<< "]";
420+
}
421+
406422
if (supported) {
407423
// If the node_unit is supported, add all of its nodes to the supported list.
408424
for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) {
409425
supported_nodes.insert(node_in_group);
410426
}
411427
}
428+
429+
node_unit_supported_result.insert({node_unit, supported});
412430
}
413431

414432
return supported_nodes;

Diff for: onnxruntime/core/providers/qnn/qnn_execution_provider.h

-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider {
4242

4343
private:
4444
bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
45-
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
4645
const logging::Logger& logger) const;
4746

4847
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,

0 commit comments

Comments
 (0)