# SPDX-License-Identifier: Apache-2.0

"""
quantize
"""

import logging

import numpy as np
from onnx.onnx_pb import TensorProto

from tf2onnx import utils
from tf2onnx.handler import tf_op
from tf2onnx.utils import make_sure

logger = logging.getLogger(__name__)

# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name

@tf_op(["FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars"])
class FakeQuantWithMinMaxArgs:
    # see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fake-quant-with-min-max-args
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
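        # FakeQuantWithMinMaxVars carries min/max as constant graph inputs,
        # while FakeQuantWithMinMaxArgs stores them as node attributes.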
        # hack to make up for the missing onnx pack op
        if node.type == "FakeQuantWithMinMaxVars":
            utils.make_sure(node.inputs[1].is_scalar(), "%s node %s requires const scalar value for min",
                            node.type, node.name)
            utils.make_sure(node.inputs[2].is_scalar(), "%s node %s requires const scalar value for max",
                            node.type, node.name)
            amin = node.inputs[1].get_tensor_value()
            amax = node.inputs[2].get_tensor_value()
        else:
            amin = node.get_attr("min").f
            amax = node.get_attr("max").f
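
        # min/max must be known at conversion time since they are folded into
        # the constant scale/zero_point initializers created below.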
        narrow_range = node.get_attr("narrow_range").i
        num_bits = node.get_attr("num_bits").i
        logger.debug(
            "Convert node FakeQuantWithMinMaxArgs with narrow_range=%r",
            narrow_range)
        make_sure(
            num_bits == 8,
            "Unable to convert node FakeQuantWithMinMaxArgs with "
            "num_bits=%r", num_bits)
        # Allow narrow_range since TensorRT requires quantized range to be (-127, 127)
        if narrow_range:
            scale = amax / (2 ** (num_bits - 1) - 1)
            idtype = TensorProto.INT8
            zero = np.zeros(np.array(amin).shape, dtype=np.int8)
        else:
            scale = (amax - amin) / (2 ** num_bits - 1)
            min_adj = np.around(amin / scale)
            idtype = TensorProto.UINT8
            zero = np.array(-min_adj, dtype=np.uint8)
            make_sure(
                zero == -min_adj,
                "Cannot convert %s node %s with "
                "min=%r max=%r numbits=%r because zero_scale=%r "
                "is outside uint8 boundary",
                node.type, node.name, amin, amax, num_bits, -min_adj)
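
        # scale and zero_point are baked into graph constants shared by the
        # QuantizeLinear/DequantizeLinear pair below; both are scalar, i.e.
        # per-tensor quantization.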
        dtype = ctx.get_dtype(node.input[0])
        shape = ctx.get_shape(node.input[0])
        axis = 1
        pb_scale = ctx.make_const(
            utils.make_name("{}_scaley".format(node.name)),
            np.array(scale, dtype=np.float32))
        zero_point = ctx.make_const(
            utils.make_name("{}_zpy".format(node.name)), zero)
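
        # Replace the FakeQuant node with a QuantizeLinear -> DequantizeLinear
        # pair, so downstream consumers still see (fake-quantized) float values.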
        new_node = ctx.make_node(
            "QuantizeLinear", [node.input[0], pb_scale.name, zero_point.name],
            op_name_scope=node.name, attr={"axis": axis},
            shapes=[shape], dtypes=[idtype])
        output_name = new_node.output[0]
        ctx.replace_input(node, node.input[0], output_name, 0)

        ctx.remove_node(node.name)

        last_node = ctx.make_node(
            "DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
            op_name_scope=node.name, attr={"axis": axis},
            shapes=[shape], dtypes=[dtype])
        ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()