
Commit 26cf74e

centwang authored and snnn committed
[QNN] MatMul Op Builder to Handle All Cases of ONNX's MatMul (#22639)
ONNX's MatMul matches numpy.matmul, which supports input tensors with rank >= 1, but QNN's MatMul only supports input tensors with rank >= 2. This PR adds a MatMulOpBuilder to QNN EP that builds a QNN graph covering all possible cases of ONNX's MatMul by inserting Reshape nodes where necessary, e.g., reshaping a 1D input to 2D, and reshaping the output back to the expected shape at the end.

This PR also uses the FullyConnected Op instead of MatMul when the 2nd input is a 2D initializer or a 1D tensor, because FullyConnected is faster than MatMul on QNN EP. When the 2nd input is a 2D tensor, we require it to be an initializer: FullyConnected expects its 2nd input in [n, k] shape, and an initializer can be transposed at graph-build time without adding an extra Transpose node.

Using the swin_base model as an example, which contains several MatMul nodes whose 2nd input is a 2D initializer (not followed by Add), running on a Gen3 mobile device: before this change it takes 34.8876 ms; after this change, 27.0639 ms.
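For context, numpy.matmul treats a 1D left operand as [1, k], treats a 1D right operand as [k, 1], broadcasts any leading batch dims, and then drops the temporarily added dims from the result. Below is a minimal C++ sketch of those shape rules, which the builder in this commit reproduces with explicit Reshape nodes; the helper name and main() are illustrative only, not part of this commit.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative helper: output shape of numpy.matmul-style MatMul for rank >= 1 inputs.
std::vector<int64_t> MatMulOutputShape(std::vector<int64_t> a, std::vector<int64_t> b) {
  const bool a_was_1d = a.size() == 1;
  const bool b_was_1d = b.size() == 1;
  if (a_was_1d) a.insert(a.begin(), 1);  // [k] -> [1, k]
  if (b_was_1d) b.push_back(1);          // [k] -> [k, 1]
  assert(a.back() == b[b.size() - 2]);   // inner dims must match

  // Right-align and broadcast the batch dims (each pair is equal or one side is 1).
  std::vector<int64_t> out;
  const size_t a_batch = a.size() - 2;
  const size_t b_batch = b.size() - 2;
  const size_t batch_rank = std::max(a_batch, b_batch);
  for (size_t i = 0; i < batch_rank; ++i) {
    const int64_t da = i < batch_rank - a_batch ? 1 : a[i - (batch_rank - a_batch)];
    const int64_t db = i < batch_rank - b_batch ? 1 : b[i - (batch_rank - b_batch)];
    out.push_back(std::max(da, db));
  }
  out.push_back(a[a.size() - 2]);  // m
  out.push_back(b.back());         // n

  // Dims added for 1D inputs are removed again -- exactly the trailing Reshape
  // this commit inserts after QNN's MatMul/FullyConnected.
  if (a_was_1d) out.erase(out.end() - 2);
  if (b_was_1d) out.pop_back();
  return out;
}

int main() {
  assert((MatMulOutputShape({2, 3, 4}, {4}) == std::vector<int64_t>{2, 3}));
  assert((MatMulOutputShape({4}, {4, 5}) == std::vector<int64_t>{5}));
  return 0;
}

(One wrinkle, noted in the builder's own comment below: for the 1D x 1D case ONNX Runtime reports the output shape as [1] rather than a rank-0 scalar.)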
1 parent 5ce797a commit 26cf74e

File tree

6 files changed, +436 −274 lines


onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

+4-1
@@ -51,7 +51,6 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
     CreateSimpleOpBuilder("Sub", *this);
     CreateSimpleOpBuilder("Tanh", *this);
 
-    CreateSimpleOpBuilder("MatMul", *this);
     CreateSimpleOpBuilder("Concat", *this);
 
     CreateSimpleOpBuilder("QuantizeLinear", *this);
@@ -170,6 +169,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
   {
     CreateExpandOpBuilder("Expand", *this);
   }
+
+  {
+    CreateMatMulOpBuilder("MatMul", *this);
+  }
 }
 
 const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
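Note why the CreateSimpleOpBuilder("MatMul", ...) line is removed above: the registry keys builders by op type, so "MatMul" must stop routing to the generic simple builder before the dedicated one is registered. A simplified sketch of the pattern follows; only the AddOpBuilder/GetOpBuilder names are taken from this commit, and the map layout is an assumption.

#include <memory>
#include <string>
#include <unordered_map>

struct IOpBuilder {
  virtual ~IOpBuilder() = default;
};

// Sketch of the registry the Create*OpBuilder helpers feed into (assumed layout,
// not the actual OpBuilderRegistrations from op_builder_factory.h).
class OpBuilderRegistrationsSketch {
 public:
  // One builder per op type; registering "MatMul" twice would be ambiguous.
  void AddOpBuilder(const std::string& op_type, std::unique_ptr<IOpBuilder> builder) {
    builders_[op_type] = std::move(builder);
  }
  const IOpBuilder* GetOpBuilder(const std::string& op_type) const {
    auto it = builders_.find(op_type);
    return it == builders_.end() ? nullptr : it->second.get();
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<IOpBuilder>> builders_;
};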

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

+2
@@ -96,5 +96,7 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
 void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 
 void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
+
+void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 }  // namespace qnn
 }  // namespace onnxruntime
onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc

+227

@@ -0,0 +1,227 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/providers/common.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/shared/utils/utils.h"

namespace onnxruntime {
namespace qnn {

/**
 * ONNX's MatMul supports 1D tensors as input on both sides, but neither QNN's MatMul nor FullyConnected supports
 * them, so we need to add Reshape Ops if necessary.
 * In two cases, FullyConnected (input_1's shape is [n, k]) is used instead of MatMul without an extra Transpose Op:
 * 1. input_1 is a 2D initializer.
 * 2. input_1 is a 1D tensor.
 */
class MatMulOpBuilder : public BaseOpBuilder {
 public:
  MatMulOpBuilder() : BaseOpBuilder("MatMulOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MatMulOpBuilder);

 protected:
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger,
                       std::vector<std::string>& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT;

  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names, const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

namespace {

Status CheckInputs(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& input_def_0,
                   const NodeUnitIODef& input_def_1, TensorInfo& input_info_0, TensorInfo& input_info_1,
                   bool& use_fully_connected) {
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def_0, input_info_0));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def_1, input_info_1));

  // Use FullyConnected if the 2nd input is a 2D initializer or a 1D tensor.
  // FullyConnected cannot pass Op validation if keep_dims is true, and if input_0 is a per-channel quantized tensor
  // with rank > 2, it's not easy to set the quantization parameters for the reshaped 2D output tensor.
  // In that case, we don't use FullyConnected.
  use_fully_connected =
      (input_info_1.shape.size() == 2 && input_info_1.is_initializer) || input_info_1.shape.size() == 1;
  use_fully_connected =
      use_fully_connected && !(input_info_0.quant_param.IsPerChannel() && input_info_0.shape.size() > 2);
  return Status::OK();
}

}  // namespace

Status MatMulOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                      const logging::Logger& logger, std::vector<std::string>& input_names,
                                      bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();
  TensorInfo input_info_0{};
  TensorInfo input_info_1{};
  bool use_fully_connected = false;
  ORT_RETURN_IF_ERROR(
      CheckInputs(qnn_model_wrapper, inputs[0], inputs[1], input_info_0, input_info_1, use_fully_connected));
  bool reshape_input_0 = input_info_0.shape.size() == 1;
  bool reshape_input_1 = input_info_1.shape.size() == 1;

  // Process input 0.
  const std::string& org_input_0_name = inputs[0].node_arg.Name();
  std::string input_0_name = org_input_0_name;
  if (reshape_input_0) {
    input_0_name = org_input_0_name + "_ort_qnn_ep_reshape";
    std::vector<uint32_t> shape_2d{1, input_info_0.shape[0]};
    QnnQuantParamsWrapper quant_param_2d = input_info_0.quant_param.Copy();
    ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze<uint32_t>(input_info_0.shape, shape_2d));

    // If input_0 is an initializer, unpack it and add the tensor with the new quantization parameters and shape.
    // Otherwise, add a Reshape node.
    if (input_info_0.is_initializer) {
      std::vector<uint8_t> unpacked_tensor;
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info_0.initializer_tensor, unpacked_tensor));
      Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(org_input_0_name);
      QnnTensorWrapper input_tensorwrapper(input_0_name, tensor_type, input_info_0.qnn_data_type,
                                           std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
    } else {
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(org_input_0_name, input_0_name, input_info_0.shape,
                                                           shape_2d, input_info_0.qnn_data_type,
                                                           input_info_0.quant_param, quant_param_2d, do_op_validation,
                                                           qnn_model_wrapper.IsGraphInput(org_input_0_name), false));
    }
  } else {
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_0_name)) {
      LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_0_name;
    } else {
      QnnTensorWrapper input_0_tensor;
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[0], input_0_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_0_tensor)), "Failed to add tensor.");
    }
  }
  input_names.emplace_back(input_0_name);

  // Process input 1.
  const std::string& org_input_1_name = inputs[1].node_arg.Name();
  std::string input_1_name = org_input_1_name;
  if (reshape_input_1 || use_fully_connected) {
    std::vector<uint32_t> shape_2d;
    QnnQuantParamsWrapper quant_param_2d = input_info_1.quant_param.Copy();
    if (reshape_input_1) {
      // Input is a 1D tensor.
      input_1_name = org_input_1_name + "_ort_qnn_ep_reshape";
      if (use_fully_connected) {
        // FullyConnected requires input_1's shape to be [n, k].
        shape_2d = {1, input_info_1.shape[0]};
      } else {
        shape_2d = {input_info_1.shape[0], 1};
      }
      ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze<uint32_t>(input_info_1.shape, shape_2d));
    } else {
      input_1_name = org_input_1_name + "_ort_qnn_ep_transpose";
      shape_2d = {input_info_1.shape[1], input_info_1.shape[0]};
      ORT_RETURN_IF_ERROR(quant_param_2d.HandleTranspose<uint32_t>(std::vector<uint32_t>({1, 0})));
    }

    // If input_1 is an initializer, unpack it and add the tensor with the new quantization parameters and shape.
    // Otherwise, add a Reshape node.
    if (input_info_1.is_initializer) {
      std::vector<uint8_t> unpacked_tensor;
      if (use_fully_connected && !reshape_input_1) {
        // A 2D initializer should be transposed to [n, k].
        ORT_RETURN_IF_ERROR(TwoDimensionTranspose(qnn_model_wrapper, input_info_1.shape,
                                                  *input_info_1.initializer_tensor, unpacked_tensor));
      } else {
        ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info_1.initializer_tensor, unpacked_tensor));
      }

      Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(org_input_1_name);
      QnnTensorWrapper input_tensorwrapper(input_1_name, tensor_type, input_info_1.qnn_data_type,
                                           std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
    } else {
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(org_input_1_name, input_1_name, input_info_1.shape,
                                                           shape_2d, input_info_1.qnn_data_type,
                                                           input_info_1.quant_param, quant_param_2d, do_op_validation,
                                                           qnn_model_wrapper.IsGraphInput(org_input_1_name), false));
    }
  } else {
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_1_name)) {
      LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_1_name;
    } else {
      QnnTensorWrapper input_1_tensor;
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], input_1_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_1_tensor)), "Failed to add tensor.");
    }
  }
  input_names.emplace_back(input_1_name);

  return Status::OK();
}

Status MatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                                    std::vector<std::string>&& input_names,
                                                    const logging::Logger& /*logger*/, bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();
  TensorInfo input_info_0{};
  TensorInfo input_info_1{};
  bool use_fully_connected = false;
  ORT_RETURN_IF_ERROR(
      CheckInputs(qnn_model_wrapper, inputs[0], inputs[1], input_info_0, input_info_1, use_fully_connected));
  bool reshape_input_0 = input_info_0.shape.size() == 1;
  bool reshape_input_1 = input_info_1.shape.size() == 1;
  bool reshape_output = reshape_input_0 || reshape_input_1 || (use_fully_connected && input_info_0.shape.size() > 2);

  const std::string& org_output_name = node_unit.Outputs()[0].node_arg.Name();
  std::string op_output_name = org_output_name;
  TensorInfo output_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));
  std::vector<uint32_t> op_output_shape = output_info.shape;
  QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy();
  if (reshape_output) {
    op_output_name = org_output_name + "_ort_qnn_ep_reshape";
    if (use_fully_connected && input_info_0.shape.size() > 2) {
      op_output_shape = {std::accumulate(input_info_0.shape.begin(), input_info_0.shape.end() - 1,
                                         static_cast<uint32_t>(1), std::multiplies<uint32_t>()),
                         reshape_input_1 ? 1 : input_info_1.shape.back()};
      ORT_ENFORCE(!op_output_quant_param.IsPerChannel());
    } else {
      // If both inputs are 1D tensors, the output shape is [1] instead of a scalar, so we only need to add one "1"
      // to op_output_shape.
      if (reshape_input_1) {
        op_output_shape.emplace_back(1);
      } else if (reshape_input_0) {
        op_output_shape.insert(op_output_shape.end() - 1, 1);
      }
      ORT_RETURN_IF_ERROR(op_output_quant_param.HandleUnsqueeze<uint32_t>(output_info.shape, op_output_shape));
    }
  }

  const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(org_output_name);
  const bool is_op_output_graph_output = is_graph_output && !reshape_output;
  Qnn_TensorType_t op_output_tensor_type =
      is_op_output_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
  QnnTensorWrapper op_output_tensor_wrapper(op_output_name, op_output_tensor_type, output_info.qnn_data_type,
                                            op_output_quant_param.Copy(), std::vector<uint32_t>(op_output_shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)),
                    "Failed to add output tensor.");
  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                    use_fully_connected ? QNN_OP_FULLY_CONNECTED : QNN_OP_MAT_MUL,
                                                    std::move(input_names), {op_output_name}, {}, do_op_validation),
                    "Failed to add fused Matmul node.");

  if (reshape_output) {
    ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(
        op_output_name, org_output_name, op_output_shape, output_info.shape, output_info.qnn_data_type,
        op_output_quant_param, output_info.quant_param, do_op_validation, false, is_graph_output));
  }

  return Status::OK();
}

void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<MatMulOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime
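The least obvious shape handling above is the FullyConnected path for a rank > 2 input_0: all leading dims are flattened into a single batch dim, and the trailing Reshape restores the ONNX output shape. A self-contained sketch with illustrative dimensions (the helper name and the concrete numbers are assumptions, not from this commit):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Mirrors the op_output_shape computation in ProcessAttributesAndOutputs:
// [d0, ..., d_{r-2}, k] x (weight [n, k]) -> [d0 * ... * d_{r-2}, n].
std::vector<uint32_t> FullyConnectedOutputShape(const std::vector<uint32_t>& input_0_shape, uint32_t n) {
  assert(input_0_shape.size() > 2);
  const uint32_t batch = std::accumulate(input_0_shape.begin(), input_0_shape.end() - 1,
                                         static_cast<uint32_t>(1), std::multiplies<uint32_t>());
  return {batch, n};
}

int main() {
  // Illustrative swin-like MatMul: input_0 [1, 49, 1024], 2D initializer weight [1024, 256].
  // The weight is transposed at graph-build time to [256, 1024] (FullyConnected wants [n, k]),
  // FullyConnected produces [49, 256], and the trailing Reshape restores [1, 49, 256].
  assert((FullyConnectedOutputShape({1, 49, 1024}, 256) == std::vector<uint32_t>{49, 256}));
  return 0;
}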

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

+22-26
@@ -495,49 +495,45 @@ Status QnnModelWrapper::GetTensorInfo(const NodeUnitIODef& input, TensorInfo& te
   return Status::OK();
 }
 
-Status QnnModelWrapper::AddReshapeNode(const std::string& input_name,
-                                       const std::string& output_name,
+Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, const std::string& output_name,
                                        const std::vector<uint32_t>& input_shape,
                                        const std::vector<uint32_t>& output_shape,
                                        const Qnn_DataType_t& tensor_data_type,
-                                       const QnnQuantParamsWrapper& quantize_param,
-                                       bool do_op_validation,
-                                       bool is_for_input,
-                                       bool is_for_output) {
-  // Do not allow QNN EP to insert Reshape nodes with per-channel quantization on dynamic tensors.
-  // We could technically support this by shifting the quantization param's axis value, but
-  // we don't need this right now.
-  ORT_RETURN_IF(quantize_param.IsPerChannel(),
-                "Do not support inserted Reshape nodes with per-channel quantization");
-  QnnTensorWrapper input_tensorwrapper(input_name,
-                                       is_for_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE,
-                                       tensor_data_type,
-                                       quantize_param.Copy(),
+                                       const QnnQuantParamsWrapper& input_quantize_param,
+                                       const QnnQuantParamsWrapper& output_quantize_param, bool do_op_validation,
+                                       bool is_for_input, bool is_for_output) {
+  QnnTensorWrapper input_tensorwrapper(input_name, is_for_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE,
+                                       tensor_data_type, input_quantize_param.Copy(),
                                        std::vector<uint32_t>(input_shape));
   ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(input_tensorwrapper)),
                     "QNN EP: Failed to add input tensor for inserted Reshape.");
 
   Qnn_TensorType_t tensor_type = is_for_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
-  QnnTensorWrapper output_tensorwrapper(output_name,
-                                        tensor_type,
-                                        tensor_data_type,
-                                        quantize_param.Copy(),
+  QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, tensor_data_type, output_quantize_param.Copy(),
                                         std::vector<uint32_t>(output_shape));
   ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)),
                     "QNN EP: Failed to add output tensor for inserted Reshape.");
 
-  ORT_RETURN_IF_NOT(CreateQnnNode(output_name,
-                                  QNN_OP_PACKAGE_NAME_QTI_AISW,
-                                  QNN_OP_RESHAPE,
-                                  {input_name},
-                                  {output_name},
-                                  {},
-                                  do_op_validation),
+  ORT_RETURN_IF_NOT(CreateQnnNode(output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {input_name},
+                                  {output_name}, {}, do_op_validation),
                     "QNN EP: Failed to create manually inserted Qnn Reshape node.");
 
   return Status::OK();
 }
 
+Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, const std::string& output_name,
+                                       const std::vector<uint32_t>& input_shape,
+                                       const std::vector<uint32_t>& output_shape,
+                                       const Qnn_DataType_t& tensor_data_type,
+                                       const QnnQuantParamsWrapper& quantize_param, bool do_op_validation,
+                                       bool is_for_input, bool is_for_output) {
+  // Do not allow QNN EP to insert Reshape nodes with per-channel quantization on dynamic tensors
+  // if only one quantization param is provided.
+  ORT_RETURN_IF(quantize_param.IsPerChannel(), "Do not support inserted Reshape nodes with per-channel quantization");
+  return AddReshapeNode(input_name, output_name, input_shape, output_shape, tensor_data_type, quantize_param,
+                        quantize_param, do_op_validation, is_for_input, is_for_output);
+}
+
 Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index,
                                          const std::string& input_name,
                                          const std::string& output_name,
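The split into two overloads exists because an inserted Reshape can move a per-channel quantization axis, so the Reshape's input and output can no longer share one quant param; callers such as the MatMul builder pass a pre-adjusted copy (via HandleUnsqueeze or HandleTranspose) for the output side. A hypothetical sketch of the axis shift an unsqueeze causes (the real adjustment lives in QnnQuantParamsWrapper::HandleUnsqueeze; this helper is illustrative only):

#include <cstdint>
#include <vector>

// Returns where a per-channel axis lands after size-1 dims are inserted.
// E.g. [k] with channel axis 0 unsqueezed to [1, k] has its axis move to 1,
// which is why the Reshape's input and output quant params can differ.
int32_t AxisAfterUnsqueeze(int32_t axis, const std::vector<uint32_t>& old_shape,
                           const std::vector<uint32_t>& new_shape) {
  size_t i_old = 0;
  for (size_t i_new = 0; i_new < new_shape.size(); ++i_new) {
    if (i_old < old_shape.size() && new_shape[i_new] == old_shape[i_old]) {
      if (static_cast<int32_t>(i_old) == axis) return static_cast<int32_t>(i_new);
      ++i_old;
    }
    // Otherwise new_shape[i_new] is one of the inserted 1s; keep scanning.
  }
  return axis;  // not reached for a valid unsqueeze of a valid axis
}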

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h

+11
@@ -141,6 +141,17 @@ class QnnModelWrapper {
 
   Status GetTensorInfo(const NodeUnitIODef& input, TensorInfo& input_info) const;
 
+  Status AddReshapeNode(const std::string& input_name,
+                        const std::string& output_name,
+                        const std::vector<uint32_t>& input_shape,
+                        const std::vector<uint32_t>& output_shape,
+                        const Qnn_DataType_t& tensor_data_type,
+                        const QnnQuantParamsWrapper& input_quantize_param,
+                        const QnnQuantParamsWrapper& output_quantize_param,
+                        bool do_op_validation,
+                        bool is_for_input = true,
+                        bool is_for_output = false);
+
   Status AddReshapeNode(const std::string& input_name,
                         const std::string& output_name,
                         const std::vector<uint32_t>& input_shape,
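For reference, a usage sketch of the new overload, mirroring the 1D-input path in matmul_op_builder.cc above (variable names come from that file; this is an excerpt-style snippet, not standalone code):

// Reshape a dynamic 1D input [k] to [1, k]: the Reshape's input keeps the
// original quant params, while its output gets the axis-adjusted copy.
std::vector<uint32_t> shape_2d{1, input_info_0.shape[0]};
QnnQuantParamsWrapper quant_param_2d = input_info_0.quant_param.Copy();
ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze<uint32_t>(input_info_0.shape, shape_2d));
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(
    org_input_0_name, input_0_name, input_info_0.shape, shape_2d,
    input_info_0.qnn_data_type, input_info_0.quant_param, quant_param_2d,
    do_op_validation, qnn_model_wrapper.IsGraphInput(org_input_0_name),
    /*is_for_output=*/false));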