
Commit 1bfb963

adrianlizarraga authored and ankitm3k committed
[Quant Tool] Prevent int32 quantized bias from clipping by adjusting the weight's scale (microsoft#22020)
### Description

Fixes a scenario in which a bias input quantized to int32 has a scale that is too small. A bias with a scale smaller than a certain threshold will overflow the range of an `int32` when quantized, which significantly decreases accuracy.

Credit to @yihonglyu for finding this issue and the fix.

### Motivation and Context

Consider the following Convolution with very small weights and a constant bias input of `[5, -4.5]`.

![image](https://github.com/user-attachments/assets/4bde2bd9-892f-4ae9-887b-61a6668779a1)

The QDQ quantizer first computes the following quantization scales for `input_0` and `weight`:

- `input_0`: scale=0.5
- `weight`: scale=7.843e-10 **[really small]**

The QDQ quantizer then computes the bias input's scale as follows:

```
bias_scale = input_0_scale * weight_0_scale = 0.5 * 7.843e-10 = 3.9215686274509805e-11
```

This `bias_scale` is too small. Before this PR, the QDQ quantizer would quantize the f32 bias with this `bias_scale`:

```
bias_quant = round(bias_f32 / bias_scale) = round([5.0/bias_scale, -4.5/bias_scale]) = [127500000000, -114750000000]
```

These quantized bias values exceed the range of int32, and so are clipped to [int32.min(), int32.max()], which is very inaccurate.

#### New approach

This PR increases `weight_0_scale` by the amount necessary to ensure that `bias_scale` (which equals `weight_0_scale * input_0_scale`) is appropriate for the int32 quantization type.

The smallest valid bias scale is given by the normal scale formula:

`bias_smallest_valid_scale = (bias_f32_max - bias_f32_min) / (int32_max - int32_min)`

Then, we compute the candidate bias scale:

`bias_scale_candidate = input_0_scale * weight_0_scale`

If the candidate scale is smaller than the smallest valid scale, we increase `weight_0_scale` by the necessary ratio:

```python
if bias_scale_candidate < bias_smallest_valid_scale:
    ratio = bias_smallest_valid_scale / bias_scale_candidate
    weight_0_scale = ratio * weight_0_scale
```

Then, we recompute the final bias scale:

```python
bias_scale = input_0_scale * weight_0_scale
```

#### Impact on accuracy

Here is the above model's quantized output compared to the f32 (ground-truth) output.

- Before PR:
  - f32 model output[0]: **5.0f**
  - qdq model output[0]: **0.075**
  - SNR: 0.1369 (higher is better)
- After PR:
  - f32 model output[0]: **5.0f**
  - qdq model output[0]: **4.992**
  - SNR: 55.656 (higher is better)
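For readers who want to reproduce the arithmetic above, here is a minimal, self-contained sketch of the described scale adjustment using the example values from this description. The helper name `adjust_weight_scale_for_int32_bias`, and the use of a symmetric range (`2 * absmax`) with a small multiplicative margin instead of the `(max - min)` form quoted above, are illustrative assumptions made so the worked numbers provably stay within int32 when the bias zero point is 0; this is not the quantizer's actual implementation.

```python
import numpy as np

INT32_MIN, INT32_MAX = np.iinfo(np.int32).min, np.iinfo(np.int32).max


def adjust_weight_scale_for_int32_bias(input_scale: float, weight_scale: float, bias_f32: np.ndarray) -> float:
    """Grow weight_scale just enough that (input_scale * weight_scale) can represent bias_f32 in int32."""
    # Bias is quantized with zero point 0, so cover the largest magnitude symmetrically,
    # with a tiny margin against edge-of-range rounding (assumption; the PR text uses (max - min)).
    absmax = float(np.max(np.abs(bias_f32)))
    bias_smallest_valid_scale = 1.0001 * 2.0 * absmax / (INT32_MAX - INT32_MIN)

    bias_scale_candidate = input_scale * weight_scale
    if bias_scale_candidate < bias_smallest_valid_scale:
        ratio = bias_smallest_valid_scale / bias_scale_candidate
        weight_scale = ratio * weight_scale
    return weight_scale


# Example values from the description: input_0 scale = 0.5, a very small weight scale, bias = [5.0, -4.5].
input_scale = 0.5
weight_scale = 7.843e-10
bias_f32 = np.array([5.0, -4.5])

# Before the fix: the quantized bias overflows int32 and gets clipped.
naive_bias_scale = input_scale * weight_scale
print(np.round(bias_f32 / naive_bias_scale))  # magnitudes ~1e10, far outside [INT32_MIN, INT32_MAX]

# After the fix: the adjusted weight scale keeps the quantized bias in range.
new_weight_scale = adjust_weight_scale_for_int32_bias(input_scale, weight_scale, bias_f32)
bias_scale = input_scale * new_weight_scale
bias_quant = np.round(bias_f32 / bias_scale)
assert (bias_quant >= INT32_MIN).all() and (bias_quant <= INT32_MAX).all()
print(bias_quant * bias_scale)  # dequantizes back to approximately [5.0, -4.5]
```

Running this reproduces the qualitative before/after behavior reported above: the naive bias scale yields clipped int32 values, while the adjusted scale dequantizes back to roughly the original bias.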
1 parent 01ecbb0 commit 1bfb963

File tree

2 files changed: +10 −313 lines


onnxruntime/python/tools/quantization/qdq_quantizer.py

+10 −38
@@ -1168,30 +1168,6 @@ def is_tensor_per_channel(
 
         return True, axis
 
-    def _get_tensor_quantization_scale(self, tensor_name: str, consumer_node_name: str) -> np.ndarray | None:
-        """
-        Returns the quantization scale of a tensor that is consumed by the given node.
-        :parameter tensor_name: The name of the tensor.
-        :parameter consumer_node_name: The name of the node that consumes the tensor as input. Necessary in case
-                                       the quantization type of the tensor was converted.
-                                       Refer: QDQQuantizer::_add_qdq_ops_for_converted_activation.
-        :returns: The quantization scale or None.
-        """
-        initializers = self.model.initializer()
-        scale_initializer: onnx.TensorProto | None = None
-
-        if tensor_name in self.quantized_value_map:
-            # Tensor was quantized by this tool, so get scale from initializer created by this tool run.
-            scale_name = self.quantized_value_map[tensor_name].get_for_consumer(consumer_node_name).scale_name
-            scale_initializer = find_by_name(scale_name, initializers)
-        else:
-            # Tensor was already quantized in original model, so get scale from DQ node that outputs the tensor.
-            dq_node = self.tensor_to_producing_dq.get(tensor_name, None)
-            if dq_node:
-                scale_initializer = find_by_name(dq_node.input[1], initializers)
-
-        return tensor_proto_to_array(scale_initializer) if scale_initializer is not None else None
-
     def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str:
         """
         Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
@@ -1201,21 +1177,17 @@ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str:
         if bias_name in self.quantized_value_map:
             return self.quantized_value_map[bias_name].original.q_name
 
-        # get scale for weight.
-        weight_scale = self._get_tensor_quantization_scale(bias_info.weight_name, bias_info.node_name)
-        if weight_scale is None:
-            raise ValueError(
-                f"Unable to get valid quantization scale for weight input '{bias_info.weight_name}' "
-                f"when quantizing bias '{bias_name}' to int32."
-            )
+        # get scale for weight
+        weight_scale_name = self.quantized_value_map[bias_info.weight_name].original.scale_name
+        weight_scale_initializer = find_by_name(weight_scale_name, self.model.initializer())
+        weight_scale = tensor_proto_to_array(weight_scale_initializer)
 
-        # get scale for input.
-        input_scale = self._get_tensor_quantization_scale(bias_info.input_name, bias_info.node_name)
-        if input_scale is None:
-            raise ValueError(
-                f"Unable to get valid quantization scale for input '{bias_info.input_name}' "
-                f"when quantizing bias '{bias_name}' to int32."
-            )
+        # get scale for input
+        input_scale_name = (
+            self.quantized_value_map[bias_info.input_name].get_for_consumer(bias_info.node_name).scale_name
+        )
+        input_scale_initializer = find_by_name(input_scale_name, self.model.initializer())
+        input_scale = tensor_proto_to_array(input_scale_initializer)
 
         (
             quantized_bias_name,

onnxruntime/test/python/quantization/test_qdq.py

-275
@@ -1927,280 +1927,5 @@ def test_dup_shared_bias(self):
         self.assertEqual(len(bias_names), 2)
 
 
-class TestQDQPrequantWeights(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.prequant_weight")
-
-        # Note: swap with the commented line if you want to see the models in local test dir.
-        cls._tmp_dir_path = cls._tmp_model_dir.name
-        # cls._tmp_dir_path = "."
-
-    @classmethod
-    def tearDownClass(cls):
-        cls._tmp_model_dir.cleanup()
-
-    def build_conv_model(
-        self,
-        inp_shape: list[int],
-        weight_quant_data: np.ndarray,
-        weight_scale_data: np.ndarray,
-        weight_zp_data: np.ndarray,
-        bias_data: np.ndarray,
-        float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT,
-    ):
-        """
-        Builds a model with a Conv that has a pre-quantized constant weight input.
-        """
-        input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, inp_shape)
-        output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, None)
-        weight_quant = onnx.numpy_helper.from_array(weight_quant_data, "weight_quant")
-        weight_scale = onnx.numpy_helper.from_array(weight_scale_data, "weight_scale")
-        weight_zp = onnx.numpy_helper.from_array(weight_zp_data, "weight_zp")
-        bias = onnx.numpy_helper.from_array(bias_data, "bias")
-
-        dq_node = onnx.helper.make_node(
-            "DequantizeLinear", ["weight_quant", "weight_scale", "weight_zp"], ["weight_dequant"], name="DQ0"
-        )
-        conv_node = onnx.helper.make_node("Conv", ["input_0", "weight_dequant", "bias"], ["output_0"], name="Conv0")
-        graph = onnx.helper.make_graph(
-            [dq_node, conv_node],
-            "ConvPreQuantWeight",
-            [input_0],
-            [output_0],
-            initializer=[weight_quant, weight_scale, weight_zp, bias],
-        )
-        opset_imports = [onnx.helper.make_opsetid("", 21)]
-        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
-
-        return onnx.shape_inference.infer_shapes(model)
-
-    def build_conv_dynamic_weight_model(
-        self,
-        input_quant_data: np.ndarray,
-        input_scale_data: np.ndarray,
-        input_zp_data: np.ndarray,
-        weight_shape: list[int],
-        bias_data: np.ndarray,
-        float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT,
-    ):
-        """
-        Builds a model with a Conv that has a dynamic float weight input, but a constant
-        pre-quantized input[0].
-        """
-        dyn_weight = onnx.helper.make_tensor_value_info("dyn_weight", float_type, weight_shape)
-        output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, None)
-        input_quant = onnx.numpy_helper.from_array(input_quant_data, "input_quant")
-        input_scale = onnx.numpy_helper.from_array(input_scale_data, "input_scale")
-        input_zp = onnx.numpy_helper.from_array(input_zp_data, "input_zp")
-        bias = onnx.numpy_helper.from_array(bias_data, "bias")
-
-        dq_node = onnx.helper.make_node(
-            "DequantizeLinear", ["input_quant", "input_scale", "input_zp"], ["input_dequant"], name="DQ0"
-        )
-        conv_node = onnx.helper.make_node("Conv", ["input_dequant", "dyn_weight", "bias"], ["output_0"], name="Conv0")
-        graph = onnx.helper.make_graph(
-            [dq_node, conv_node],
-            "ConvPreQuantInput_DynamicWeight",
-            [dyn_weight],
-            [output_0],
-            initializer=[input_quant, input_scale, input_zp, bias],
-        )
-        opset_imports = [onnx.helper.make_opsetid("", 21)]
-        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
-
-        return onnx.shape_inference.infer_shapes(model)
-
-    def test_quantize_with_prequantized_weights(self):
-        """
-        Test quantization of Conv with pre-quantized weights.
-        """
-        rng = np.random.default_rng(123)
-        test_configs = [onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16]
-
-        for float_type in test_configs:
-            with self.subTest(float_type=float_type):
-                label = f"_{onnx.TensorProto.DataType.Name(float_type)}"
-                float_model_path = os.path.join(self._tmp_dir_path, f"conv.f32.prequant_weight{label}.onnx")
-                qdq_model_path = os.path.join(self._tmp_dir_path, f"conv.prequant_weight{label}.qdq.onnx")
-
-                inp_shape = [1, 2, 100, 100]
-                weight_shape = [2, 2, 20, 20]
-                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type)
-
-                # range = 2.0, scale = 2/254, zp = 0
-                weight_scale_data = np.array(2 / 254, dtype=np_dtype)
-                weight_zp_data = np.array(0, dtype=np.int8)
-                weight_data = np.linspace(-1.0, 1.0, num=1600, dtype=np_dtype).reshape(weight_shape)
-                weight_quant_data = quantize_nparray(
-                    onnx.TensorProto.INT8, weight_data, weight_scale_data, weight_zp_data
-                )
-
-                bias_data = np.array([-10.0, 10.0], dtype=np_dtype)
-                float_model = self.build_conv_model(
-                    inp_shape, weight_quant_data, weight_scale_data, weight_zp_data, bias_data, float_type
-                )
-
-                onnx.checker.check_model(float_model, True)
-                onnx.save_model(float_model, float_model_path)
-
-                # Check that the input model only has a pre-quantized weight and save its scale/zero-point
-                # to check that it doesn't change after quantization.
-                float_node_counts = {"QuantizeLinear": 0, "DequantizeLinear": 1}
-                check_op_type_count(self, float_model_path, **float_node_counts)
-                conv_node_original = next((node for node in float_model.graph.node if node.op_type == "Conv"), None)
-                self.assertNotEqual(conv_node_original, None)
-
-                _, producers_original = get_tensor_consumers_and_producers(float_model)
-                weight_dq_node_original = producers_original.get(conv_node_original.input[1], None)
-                initializers_original = {initializer.name: initializer for initializer in float_model.graph.initializer}
-                scale_name_original = weight_dq_node_original.input[1]
-                scale_val_original = onnx.numpy_helper.to_array(initializers_original[scale_name_original])
-                zp_name_original = weight_dq_node_original.input[2]
-                zp_val_original = onnx.numpy_helper.to_array(initializers_original[zp_name_original])
-
-                input_data_list = [
-                    {"input_0": rng.uniform(-10.0, 10.0, inp_shape).astype(np_dtype)},
-                ]
-                data_reader = TestDataFeeds(input_data_list)
-
-                quantize_static(
-                    float_model_path,
-                    qdq_model_path,
-                    data_reader,
-                    quant_format=QuantFormat.QDQ,
-                    activation_type=QuantType.QUInt8,
-                    weight_type=QuantType.QInt8,
-                    op_types_to_quantize=["Conv"],
-                )
-
-                # The final model should have everything quantized
-                qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4}
-                check_op_type_count(self, qdq_model_path, **qdq_node_counts)
-
-                # Check that the pre-quantized weight still has the same scale/zp after quantization
-                qdq_model = onnx.load_model(qdq_model_path)
-                conv_node = next((node for node in qdq_model.graph.node if node.op_type == "Conv"), None)
-                self.assertNotEqual(conv_node, None)
-
-                _, producers = get_tensor_consumers_and_producers(qdq_model)
-                weight_dq_node = producers.get(conv_node.input[1], None)
-                initializers = {initializer.name: initializer for initializer in qdq_model.graph.initializer}
-
-                scale_name = weight_dq_node.input[1]
-                self.assertEqual(scale_name, scale_name_original)
-                scale_val = onnx.numpy_helper.to_array(initializers[scale_name])
-                self.assertEqual(scale_val, scale_val_original)
-
-                zp_name = weight_dq_node.input[2]
-                self.assertEqual(zp_name, zp_name_original)
-                zp_val = onnx.numpy_helper.to_array(initializers[zp_name])
-                self.assertEqual(zp_val, zp_val_original)
-
-    def test_quantize_with_prequantized_input(self):
-        """
-        Test quantization of Conv with pre-quantized input and dynamic weight.
-        """
-        rng = np.random.default_rng(123)
-        test_configs = [
-            (onnx.TensorProto.FLOAT, False),
-            (onnx.TensorProto.FLOAT16, False),
-            (onnx.TensorProto.FLOAT, True),
-            (onnx.TensorProto.FLOAT16, True),
-        ]
-
-        for float_type, convert_weight_qtype in test_configs:
-            with self.subTest(float_type=float_type):
-                convert_label = "_convert_qtype" if convert_weight_qtype else ""
-                label = f"_{onnx.TensorProto.DataType.Name(float_type)}{convert_label}"
-                float_model_path = os.path.join(self._tmp_dir_path, f"conv.f32.prequant_input{label}.onnx")
-                qdq_model_path = os.path.join(self._tmp_dir_path, f"conv.prequant_input{label}.qdq.onnx")
-
-                inp_shape = [1, 2, 40, 40]
-                weight_shape = [2, 2, 20, 20]
-                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type)
-
-                # range = 3.0, scale = 3/255, zp = 127
-                input_scale_data = np.array(3 / 255, dtype=np_dtype)
-                input_zp_data = np.array(127, dtype=np.uint8)
-                input_data = np.linspace(-1.5, 1.5, num=3200, dtype=np_dtype).reshape(inp_shape)
-                input_quant_data = quantize_nparray(onnx.TensorProto.UINT8, input_data, input_scale_data, input_zp_data)
-
-                bias_data = np.array([-10.0, 10.0], dtype=np_dtype)
-                float_model = self.build_conv_dynamic_weight_model(
-                    input_quant_data, input_scale_data, input_zp_data, weight_shape, bias_data, float_type
-                )
-
-                onnx.checker.check_model(float_model, True)
-                onnx.save_model(float_model, float_model_path)
-
-                # Check that the input model only has a pre-quantized input and save its scale/zero-point
-                # to check that it doesn't change after quantization.
-                float_node_counts = {"QuantizeLinear": 0, "DequantizeLinear": 1}
-                check_op_type_count(self, float_model_path, **float_node_counts)
-                conv_node_original = next((node for node in float_model.graph.node if node.op_type == "Conv"), None)
-                self.assertNotEqual(conv_node_original, None)
-
-                _, producers_original = get_tensor_consumers_and_producers(float_model)
-                input_dq_node_original = producers_original.get(conv_node_original.input[0], None)
-                initializers_original = {initializer.name: initializer for initializer in float_model.graph.initializer}
-                scale_name_original = input_dq_node_original.input[1]
-                scale_val_original = onnx.numpy_helper.to_array(initializers_original[scale_name_original])
-                zp_name_original = input_dq_node_original.input[2]
-                zp_val_original = onnx.numpy_helper.to_array(initializers_original[zp_name_original])
-
-                # Create data reader with random input calibration data.
-                dyn_weight_data_list = [
-                    {"dyn_weight": rng.uniform(-10.0, 10.0, weight_shape).astype(np_dtype)},
-                ]
-                data_reader = TestDataFeeds(dyn_weight_data_list)
-
-                extra_options = {}
-                if convert_weight_qtype:
-                    # Test converting the dynamic weight's quantization type, which results in
-                    # dyn_weight -> Q(u16) -> DQ(f32) -> Q(u8) -> DQ(f32) -> Conv
-                    extra_options["TensorQuantOverrides"] = {
-                        "dyn_weight": [{"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8}}],
-                    }
-
-                quantize_static(
-                    float_model_path,
-                    qdq_model_path,
-                    data_reader,
-                    quant_format=QuantFormat.QDQ,
-                    activation_type=QuantType.QUInt8,
-                    weight_type=QuantType.QInt8,
-                    op_types_to_quantize=["Conv"],
-                    extra_options=extra_options,
-                )
-
-                # The final model should have everything quantized
-                qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4}
-                if convert_weight_qtype:
-                    qdq_node_counts["QuantizeLinear"] += 1
-                    qdq_node_counts["DequantizeLinear"] += 1
-
-                check_op_type_count(self, qdq_model_path, **qdq_node_counts)
-
-                # Check that the pre-quantized input still has the same scale/zp after quantization
-                qdq_model = onnx.load_model(qdq_model_path)
-                conv_node = next((node for node in qdq_model.graph.node if node.op_type == "Conv"), None)
-                self.assertNotEqual(conv_node, None)
-
-                _, producers = get_tensor_consumers_and_producers(qdq_model)
-                input_dq_node = producers.get(conv_node.input[0], None)
-                initializers = {initializer.name: initializer for initializer in qdq_model.graph.initializer}
-
-                scale_name = input_dq_node.input[1]
-                self.assertEqual(scale_name, scale_name_original)
-                scale_val = onnx.numpy_helper.to_array(initializers[scale_name])
-                self.assertEqual(scale_val, scale_val_original)
-
-                zp_name = input_dq_node.input[2]
-                self.assertEqual(zp_name, zp_name_original)
-                zp_val = onnx.numpy_helper.to_array(initializers[zp_name])
-                self.assertEqual(zp_val, zp_val_original)
-
-
 if __name__ == "__main__":
     unittest.main()
