// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/providers/common.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/shared/utils/utils.h"

namespace onnxruntime {
namespace qnn {

/**
 * ONNX's MatMul supports 1D tensors as inputs on either side, but neither QNN's MatMul nor FullyConnected does,
 * so we need to add Reshape ops where necessary.
 * In two cases, FullyConnected (which expects input_1's shape to be [n, k]) is used instead of MatMul, without an
 * extra Transpose op:
 * 1. input_1 is a 2D initializer.
 * 2. input_1 is a 1D tensor.
 */
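// Illustrative shape handling (examples derived from the logic below, not exhaustive):
//   [m, k] x [k]                 : input_1 reshaped to [1, k], FullyConnected, output [m, 1] reshaped back to [m]
//   [k] x [k, n] (dynamic)       : input_0 reshaped to [1, k], MatMul, output [1, n] reshaped back to [n]
//   [b, m, k] x [k, n] (initializer): input_1 transposed to [n, k], FullyConnected, output [b * m, n] reshaped
//                                     back to [b, m, n]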
class MatMulOpBuilder : public BaseOpBuilder {
 public:
  MatMulOpBuilder() : BaseOpBuilder("MatMulOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MatMulOpBuilder);

 protected:
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger,
                       std::vector<std::string>& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT;

  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names, const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

namespace {

Status CheckInputs(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& input_def_0,
                   const NodeUnitIODef& input_def_1, TensorInfo& input_info_0, TensorInfo& input_info_1,
                   bool& use_fully_connected) {
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def_0, input_info_0));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def_1, input_info_1));

  // Use FullyConnected if the 2nd input is a 2D initializer or a 1D tensor.
  // FullyConnected cannot pass Op validation if keep_dims is true, and if input_0 is a per-channel quantized tensor
  // with rank > 2, it's not easy to set the quantization parameters for the reshaped 2D output tensor.
  // In that case, don't use FullyConnected.
  use_fully_connected =
      (input_info_1.shape.size() == 2 && input_info_1.is_initializer) || input_info_1.shape.size() == 1;
  use_fully_connected =
      use_fully_connected && !(input_info_0.quant_param.IsPerChannel() && input_info_0.shape.size() > 2);
  return Status::OK();
}

}  // namespace

Status MatMulOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                      const logging::Logger& logger, std::vector<std::string>& input_names,
                                      bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();
  TensorInfo input_info_0{};
  TensorInfo input_info_1{};
  bool use_fully_connected = false;
  ORT_RETURN_IF_ERROR(
      CheckInputs(qnn_model_wrapper, inputs[0], inputs[1], input_info_0, input_info_1, use_fully_connected));
  bool reshape_input_0 = input_info_0.shape.size() == 1;
  bool reshape_input_1 = input_info_1.shape.size() == 1;

  // Process input 0.
  const std::string& org_input_0_name = inputs[0].node_arg.Name();
  std::string input_0_name = org_input_0_name;
  if (reshape_input_0) {
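    // input_0 is a 1D tensor: reshape [k] -> [1, k] and update its quantization parameters for the unsqueezed shape.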
    input_0_name = org_input_0_name + "_ort_qnn_ep_reshape";
    std::vector<uint32_t> shape_2d{1, input_info_0.shape[0]};
    QnnQuantParamsWrapper quant_param_2d = input_info_0.quant_param.Copy();
    ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze<uint32_t>(input_info_0.shape, shape_2d));

    // If input_0 is an initializer, unpack it and add the tensor with the new quantization parameters and shape.
    // Otherwise, add a Reshape node.
    if (input_info_0.is_initializer) {
      std::vector<uint8_t> unpacked_tensor;
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info_0.initializer_tensor, unpacked_tensor));
      Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(org_input_0_name);
      QnnTensorWrapper input_tensorwrapper(input_0_name, tensor_type, input_info_0.qnn_data_type,
                                           std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
    } else {
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(org_input_0_name, input_0_name, input_info_0.shape, shape_2d,
                                                           input_info_0.qnn_data_type, input_info_0.quant_param,
                                                           quant_param_2d, do_op_validation,
                                                           qnn_model_wrapper.IsGraphInput(org_input_0_name), false));
    }
  } else {
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_0_name)) {
      LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_0_name;
    } else {
      QnnTensorWrapper input_0_tensor;
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[0], input_0_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_0_tensor)), "Failed to add tensor.");
    }
  }
  input_names.emplace_back(input_0_name);

  // Process input 1.
  const std::string& org_input_1_name = inputs[1].node_arg.Name();
  std::string input_1_name = org_input_1_name;
  if (reshape_input_1 || use_fully_connected) {
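    // Three cases are handled here:
    //   1D input_1 consumed by FullyConnected     -> reshape [k] to [1, k]
    //   1D input_1 consumed by MatMul             -> reshape [k] to [k, 1]
    //   2D initializer consumed by FullyConnected -> transpose [k, n] to [n, k]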
    std::vector<uint32_t> shape_2d;
    QnnQuantParamsWrapper quant_param_2d = input_info_1.quant_param.Copy();
    if (reshape_input_1) {
      // Input is a 1D tensor.
      input_1_name = org_input_1_name + "_ort_qnn_ep_reshape";
      if (use_fully_connected) {
        // FullyConnected requires input_1's shape to be [n, k].
        shape_2d = {1, input_info_1.shape[0]};
      } else {
        shape_2d = {input_info_1.shape[0], 1};
      }
      ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze<uint32_t>(input_info_1.shape, shape_2d));
    } else {
      input_1_name = org_input_1_name + "_ort_qnn_ep_transpose";
      shape_2d = {input_info_1.shape[1], input_info_1.shape[0]};
      ORT_RETURN_IF_ERROR(quant_param_2d.HandleTranspose<uint32_t>(std::vector<uint32_t>({1, 0})));
    }

    // If input_1 is an initializer, unpack it and add the tensor with the new quantization parameters and shape.
    // Otherwise, add a Reshape node.
    if (input_info_1.is_initializer) {
      std::vector<uint8_t> unpacked_tensor;
      if (use_fully_connected && !reshape_input_1) {
        // The 2D initializer should be transposed to [n, k].
        ORT_RETURN_IF_ERROR(TwoDimensionTranspose(qnn_model_wrapper, input_info_1.shape,
                                                  *input_info_1.initializer_tensor, unpacked_tensor));
      } else {
        ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info_1.initializer_tensor, unpacked_tensor));
      }

      Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(org_input_1_name);
      QnnTensorWrapper input_tensorwrapper(input_1_name, tensor_type, input_info_1.qnn_data_type,
                                           std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
    } else {
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(org_input_1_name, input_1_name, input_info_1.shape, shape_2d,
                                                           input_info_1.qnn_data_type, input_info_1.quant_param,
                                                           quant_param_2d, do_op_validation,
                                                           qnn_model_wrapper.IsGraphInput(org_input_1_name), false));
    }
  } else {
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_1_name)) {
      LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_1_name;
    } else {
      QnnTensorWrapper input_1_tensor;
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], input_1_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_1_tensor)), "Failed to add tensor.");
    }
  }
  input_names.emplace_back(input_1_name);

  return Status::OK();
}

| 161 | + |
| 162 | +Status MatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, |
| 163 | + std::vector<std::string>&& input_names, |
| 164 | + const logging::Logger& /*logger*/, bool do_op_validation) const { |
| 165 | + const auto& inputs = node_unit.Inputs(); |
| 166 | + TensorInfo input_info_0{}; |
| 167 | + TensorInfo input_info_1{}; |
| 168 | + bool use_fully_connected = false; |
| 169 | + ORT_RETURN_IF_ERROR( |
| 170 | + CheckInputs(qnn_model_wrapper, inputs[0], inputs[1], input_info_0, input_info_1, use_fully_connected)); |
| 171 | + bool reshape_input_0 = input_info_0.shape.size() == 1; |
| 172 | + bool reshape_input_1 = input_info_1.shape.size() == 1; |
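  // The op output needs a trailing Reshape back to the original ONNX shape if either input was reshaped, or if
  // FullyConnected implicitly flattens a rank > 2 input_0.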
  bool reshape_output = reshape_input_0 || reshape_input_1 || (use_fully_connected && input_info_0.shape.size() > 2);

  const std::string& org_output_name = node_unit.Outputs()[0].node_arg.Name();
  std::string op_output_name = org_output_name;
  TensorInfo output_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));
  std::vector<uint32_t> op_output_shape = output_info.shape;
  QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy();
  if (reshape_output) {
    op_output_name = org_output_name + "_ort_qnn_ep_reshape";
    if (use_fully_connected && input_info_0.shape.size() > 2) {
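      // FullyConnected flattens input_0's leading dimensions, so its direct output shape is
      // [prod(input_0.shape[:-1]), n], e.g. [b, m, k] x [k, n] -> [b * m, n], reshaped back to [b, m, n] below.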
      op_output_shape = {std::accumulate(input_info_0.shape.begin(), input_info_0.shape.end() - 1,
                                         static_cast<uint32_t>(1), std::multiplies<uint32_t>()),
                         reshape_input_1 ? 1 : input_info_1.shape.back()};
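      // Reshaping per-channel quantization parameters to the flattened 2D shape is not supported here
      // (see CheckInputs).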
      ORT_ENFORCE(!op_output_quant_param.IsPerChannel());
    } else {
      // If both inputs are 1D tensors, the output shape is [1] rather than a scalar, so only one "1" needs to be
      // added to op_output_shape (hence the else-if below).
      if (reshape_input_1) {
        op_output_shape.emplace_back(1);
      } else if (reshape_input_0) {
        op_output_shape.insert(op_output_shape.end() - 1, 1);
      }
      ORT_RETURN_IF_ERROR(op_output_quant_param.HandleUnsqueeze<uint32_t>(output_info.shape, op_output_shape));
    }
  }

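  // When a trailing Reshape restores the original ONNX shape, the QNN op's direct output is an intermediate
  // (native) tensor; otherwise it is the graph output read by the client app.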
  const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(org_output_name);
  const bool is_op_output_graph_output = is_graph_output && !reshape_output;
  Qnn_TensorType_t op_output_tensor_type =
      is_op_output_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
  QnnTensorWrapper op_output_tensor_wrapper(op_output_name, op_output_tensor_type, output_info.qnn_data_type,
                                            op_output_quant_param.Copy(), std::vector<uint32_t>(op_output_shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)),
                    "Failed to add output tensor.");
  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                    use_fully_connected ? QNN_OP_FULLY_CONNECTED : QNN_OP_MAT_MUL,
                                                    std::move(input_names), {op_output_name}, {}, do_op_validation),
                    "Failed to add MatMul/FullyConnected node.");

  if (reshape_output) {
    ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(
        op_output_name, org_output_name, op_output_shape, output_info.shape, output_info.qnn_data_type,
        op_output_quant_param, output_info.quant_param, do_op_validation, false, is_graph_output));
  }

  return Status::OK();
}

void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<MatMulOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime