Skip to content

Commit 1ccb164

Browse files
authored
Improve the script to add Q, DQ nodes around EPContext node (microsoft#20107)
Improve the script to add Q, DQ nodes around the EPContext node so that the wrapper model uses float data as inputs and outputs. Users don't need to quantize or dequantize the data in their application.
1 parent c529e05 commit 1ccb164

File tree

1 file changed

+78
-7
lines changed

1 file changed

+78
-7
lines changed

onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py

+78-7
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@ class QnnTensorStruct:
1414
def __init__(self):
    """Default-initialize the metadata describing one QNN tensor.

    Holds the tensor's name, its ONNX element type, its shape, and —
    when the tensor is quantized — the scale/offset pair needed to
    build Q/DQ nodes around the EPContext node.
    """
    # Identity and shape.
    self.name = ""
    self.dim = []
    # ONNX element type; defaults to float until overwritten by the parser.
    self.onnx_data_type = TensorProto.FLOAT
    # Quantization info (only meaningful when is_quantized is True).
    self.is_quantized = False
    self.scale = 0.0
    self.offset = 0
1821

1922

23+
def is_quantized_data_type(qnn_data_type):
    """Return True when *qnn_data_type* is a QNN fixed-point (quantized) type code.

    Recognized codes: QNN_DATATYPE_UFIXED_POINT_8 (0x0408),
    QNN_DATATYPE_UFIXED_POINT_16 (0x0416), QNN_DATATYPE_FIXED_POINT_8 (0x0308)
    and QNN_DATATYPE_FIXED_POINT_16 (0x0316).
    """
    return qnn_data_type in (0x0408, 0x0416, 0x0308, 0x0316)
26+
27+
2028
def qnn_data_type_to_onnx_data_type(qnn_data_type):
2129
# QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UINT_8
2230
if qnn_data_type == 0x0408 or qnn_data_type == 0x0108:
@@ -73,15 +81,29 @@ def parse_qnn_json_file(qnn_json_file_path, qnn_input_tensor_dic, qnn_output_ten
7381
qnn_tensor = QnnTensorStruct()
7482
qnn_tensor.name = qnn_tensor_name
7583
qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"])
84+
qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"])
7685
qnn_tensor.dim = qnn_tensor_attribute["dims"]
86+
if (
87+
qnn_tensor_attribute["quant_params"]["definition"] == 1
88+
and qnn_tensor_attribute["quant_params"]["encoding"] == 0
89+
):
90+
qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"]
91+
qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"]
7792
qnn_input_tensor_dic[qnn_tensor_name] = qnn_tensor
7893

7994
# Get all graph outputs
8095
if qnn_tensor_attribute["type"] == 1:
8196
qnn_tensor = QnnTensorStruct()
8297
qnn_tensor.name = qnn_tensor_name
8398
qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"])
99+
qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"])
84100
qnn_tensor.dim = qnn_tensor_attribute["dims"]
101+
if (
102+
qnn_tensor_attribute["quant_params"]["definition"] == 1
103+
and qnn_tensor_attribute["quant_params"]["encoding"] == 0
104+
):
105+
qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"]
106+
qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"]
85107
qnn_output_tensor_dic[qnn_tensor_name] = qnn_tensor
86108

87109
assert (
@@ -120,13 +142,33 @@ def main():
120142
ep_cache_context_content = file.read()
121143
ctx_embed_mode = 1
122144

123-
qnn_inputs = []
124-
for qnn_input in qnn_input_tensor_dic.values():
125-
qnn_inputs.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim))
145+
graph_nodes = []
146+
ini_list = []
147+
value_infos = []
126148

127-
qnn_outputs = []
128-
for qnn_output in qnn_output_tensor_dic.values():
129-
qnn_outputs.append(helper.make_tensor_value_info(qnn_output.name, qnn_output.onnx_data_type, qnn_output.dim))
149+
model_inputs = []
150+
for qnn_input in qnn_input_tensor_dic.values():
151+
if qnn_input.is_quantized:
152+
q_scale_input_name = qnn_input.name + "_scale"
153+
q_offset_input_name = qnn_input.name + "_zp"
154+
q_scale = helper.make_tensor(q_scale_input_name, TensorProto.FLOAT, [], [qnn_input.scale])
155+
ini_list.append(q_scale)
156+
q_offset = helper.make_tensor(q_offset_input_name, qnn_input.onnx_data_type, [], [qnn_input.offset])
157+
ini_list.append(q_offset)
158+
input_name = qnn_input.name + "_dq"
159+
160+
q_node = helper.make_node(
161+
"QuantizeLinear",
162+
name=qnn_input.name,
163+
inputs=[input_name, q_scale_input_name, q_offset_input_name],
164+
outputs=[qnn_input.name],
165+
)
166+
167+
graph_nodes.append(q_node)
168+
model_inputs.append(helper.make_tensor_value_info(input_name, TensorProto.FLOAT, qnn_input.dim))
169+
value_infos.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim))
170+
else:
171+
model_inputs.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim))
130172

131173
qnn_ep_context_node = helper.make_node(
132174
"EPContext",
@@ -138,8 +180,37 @@ def main():
138180
source="Qnn",
139181
domain="com.microsoft",
140182
)
183+
graph_nodes.append(qnn_ep_context_node)
141184

142-
graph_def = helper.make_graph([qnn_ep_context_node], "qnn-onnx-model", qnn_inputs, qnn_outputs)
185+
model_outputs = []
186+
for qnn_output in qnn_output_tensor_dic.values():
187+
if qnn_output.is_quantized:
188+
dq_scale_input_name = qnn_output.name + "_scale"
189+
dq_offset_input_name = qnn_output.name + "_zp"
190+
dq_scale = helper.make_tensor(dq_scale_input_name, TensorProto.FLOAT, [], [qnn_output.scale])
191+
ini_list.append(dq_scale)
192+
dq_offset = helper.make_tensor(dq_offset_input_name, qnn_output.onnx_data_type, [], [qnn_output.offset])
193+
ini_list.append(dq_offset)
194+
output_name = qnn_output.name + "_dq"
195+
196+
dq_node = helper.make_node(
197+
"DequantizeLinear",
198+
name=output_name,
199+
inputs=[qnn_output.name, dq_scale_input_name, dq_offset_input_name],
200+
outputs=[output_name],
201+
)
202+
203+
graph_nodes.append(dq_node)
204+
model_outputs.append(helper.make_tensor_value_info(output_name, TensorProto.FLOAT, qnn_output.dim))
205+
value_infos.append(
206+
helper.make_tensor_value_info(qnn_output.name, qnn_output.onnx_data_type, qnn_output.dim)
207+
)
208+
else:
209+
model_outputs.append(
210+
helper.make_tensor_value_info(qnn_output.name, qnn_output.onnx_data_type, qnn_output.dim)
211+
)
212+
213+
graph_def = helper.make_graph(graph_nodes, "qnn-onnx-model", model_inputs, model_outputs, ini_list, "", value_infos)
143214

144215
model_def = helper.make_model(graph_def, producer_name="MS")
145216

0 commit comments

Comments
 (0)