Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d0ba055

Browse files
authored
QAT folding update (#1639)
* Add transformation to propagate dequantize op through split
* Remove requirement that QuantizeLinear must be next to DequantizeLinear for input branch of Conv node
* Fixed embedding quantization propagation
* Quality fixes
* Add zero point to dequant node
* Add zero point to initializers
* Style fixes
* Fix data type
* Allow MatMul weight to be on either input 0 or 1
* Style fixes
* Add padding value
* Make initializers distinct
* Style and quality fixes
* Make bias optional for Conv QAT conversion
* Quality fix
1 parent 0617a9e commit d0ba055

File tree

7 files changed

+230
-18
lines changed

7 files changed

+230
-18
lines changed

src/sparseml/exporters/onnx_to_deepsparse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
sparseml_transforms.DeleteRepeatedQdq(),
7676
sparseml_transforms.QuantizeQATEmbedding(),
7777
sparseml_transforms.PropagateEmbeddingQuantization(),
78+
sparseml_transforms.PropagateDequantThroughSplit(),
7879
sparseml_transforms.MatMulToQLinearMatMul(),
7980
sparseml_transforms.MatMulAddToMatMulIntegerAddCastMul(),
8081
sparseml_transforms.MatMulToMatMulIntegerCastMul(),

src/sparseml/exporters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .matmul_add_to_matmulinteger_add_cast_mul import MatMulAddToMatMulIntegerAddCastMul
4242
from .matmul_to_matmulinteger_cast_mul import MatMulToMatMulIntegerCastMul
4343
from .propagate_embedding_quantization import PropagateEmbeddingQuantization
44+
from .propagate_dequant_through_split import PropagateDequantThroughSplit
4445
from .quantize_qat_embedding import QuantizeQATEmbedding
4546
from .quantize_residuals import QuantizeResiduals
4647
from .remove_duplicate_qconv_weights import RemoveDuplicateQConvWeights

src/sparseml/exporters/transforms/conv_to_convinteger_add_cast_mul.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ class ConvToConvIntegerAddCastMul(OnnxTransform):
6666

6767
def transform(self, model: ModelProto) -> ModelProto:
6868
graph = ONNXGraph(model)
69-
matches = get_structural_matches(
69+
70+
# Nodes with bias
71+
matches_bias = get_structural_matches(
7072
graph,
7173
parent_ops=[
72-
["QuantizeLinear", "DequantizeLinear"],
74+
["DequantizeLinear"],
7375
[
7476
# weight should be initializer
7577
INITIALIZER_MATCH,
@@ -78,20 +80,50 @@ def transform(self, model: ModelProto) -> ModelProto:
7880
],
7981
[
8082
# bias should be initializer
81-
INITIALIZER_MATCH
83+
INITIALIZER_MATCH,
84+
],
85+
],
86+
op_type="Conv",
87+
)
88+
89+
# Nodes without bias
90+
matches_no_bias = get_structural_matches(
91+
graph,
92+
parent_ops=[
93+
["DequantizeLinear"],
94+
[
95+
# weight should be initializer
96+
INITIALIZER_MATCH,
97+
"QuantizeLinear",
98+
"DequantizeLinear",
8299
],
83100
],
84101
op_type="Conv",
85102
)
103+
104+
matches = matches_bias
105+
matches_names = [m.node.name for m in matches]
106+
for match in matches_no_bias:
107+
if match.node.name not in matches_names:
108+
matches.append(match)
109+
86110
for match in matches:
87111
self.log_match(match)
88112
self._transform_match(graph, model, match)
89113
return model
90114

91-
def _transform_match(self, graph: ONNXGraph, model: ModelProto, match: MatchResult):
92-
input_quant, input_dequant = match.parents[0]
115+
def _transform_match(
116+
self,
117+
graph: ONNXGraph,
118+
model: ModelProto,
119+
match: MatchResult,
120+
):
121+
(input_dequant,) = match.parents[0]
93122
weight_init, weight_quantize_node, weight_dequantize_node = match.parents[1]
94-
(bias_init,) = match.parents[2]
123+
if len(match.parents) == 3:
124+
(bias_init,) = match.parents[2]
125+
else:
126+
bias_init = None
95127

96128
model = add_quantized_conv_matmul_add_ops(
97129
model=model,

src/sparseml/exporters/transforms/matmul_add_to_matmulinteger_add_cast_mul.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,37 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform):
6767

6868
def transform(self, model: ModelProto) -> ModelProto:
6969
graph = ONNXGraph(model)
70+
71+
# Weight on input 0
72+
matches = get_structural_matches(
73+
graph,
74+
op_type="MatMul",
75+
parent_ops=[
76+
[
77+
# weight should be initializer
78+
INITIALIZER_MATCH,
79+
"QuantizeLinear",
80+
"DequantizeLinear",
81+
optional_node("Transpose"),
82+
],
83+
[any_of("QuantizeLinear", "DequantizeLinear")],
84+
],
85+
children_ops=[[optional_node("Add")]],
86+
)
87+
for match in matches:
88+
add_node = match.children[0][0]
89+
bias_init = None
90+
if add_node:
91+
# NOTE: bias could be either input 0 or 1 of add node
92+
# if add does not have a bias initializer,
93+
# still do conversion, but do not fold the bias add to rescale
94+
bias_init = graph.get_init_by_name(match.children[0][0].input[1])
95+
if bias_init is None:
96+
bias_init = graph.get_init_by_name(match.children[0][0].input[0])
97+
self.log_match(match)
98+
self._transform_match(graph, model, match, bias_init, 0)
99+
100+
# Weight on input 1
70101
matches = get_structural_matches(
71102
graph,
72103
op_type="MatMul",
@@ -93,7 +124,8 @@ def transform(self, model: ModelProto) -> ModelProto:
93124
if bias_init is None:
94125
bias_init = graph.get_init_by_name(match.children[0][0].input[0])
95126
self.log_match(match)
96-
self._transform_match(graph, model, match, bias_init)
127+
self._transform_match(graph, model, match, bias_init, 1)
128+
97129
return model
98130

99131
def _transform_match(
@@ -102,10 +134,15 @@ def _transform_match(
102134
model: ModelProto,
103135
match: MatchResult,
104136
bias_init: TensorProto,
137+
weight_parent: int,
105138
):
106139
matmul = match.node
107-
(input_quant,) = match.parents[0]
108-
weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[1]
140+
if weight_parent == 0:
141+
(input_quant,) = match.parents[1]
142+
weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[0]
143+
else:
144+
(input_quant,) = match.parents[0]
145+
weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[1]
109146
(add,) = match.children[0]
110147

111148
input_quantize_params = get_quantization_params(

src/sparseml/exporters/transforms/propagate_dequant_through_split.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import onnx
16+
from onnx import ModelProto
17+
18+
from sparseml.exporters.transforms import OnnxTransform
19+
from sparseml.exporters.transforms.utils import MatchResult, get_structural_matches
20+
from sparseml.onnx.utils import ONNXGraph
21+
22+
23+
__all__ = ["PropagateDequantThroughSplit"]
24+
25+
26+
class PropagateDequantThroughSplit(OnnxTransform):
    """
    A pass for propagating DequantizeLinear nodes down through a Split node
    so that any quantized operations after the split can be properly
    converted.

    Starting with:
    |           INPUT
    |             |
    |      DequantizeLinear
    |             |
    |           Split
    |        |    |    |

    Converts to:
    |                      INPUT
    |                        |
    |                      Split
    |          |             |              |
    |   DequantizeLinear  DequantizeLinear  DequantizeLinear
    |          |             |              |
    """

    def transform(self, model: ModelProto) -> ModelProto:
        """
        Find every Split node whose first input is produced by a
        DequantizeLinear node and push the dequantization below the split.

        :param model: the ONNX model to transform; it is modified in place
        :return: the same model object that was passed in
        """
        graph = ONNXGraph(model)
        matches = get_structural_matches(
            graph,
            parent_ops=[["DequantizeLinear"]],
            op_type="Split",
        )
        for match in matches:
            self.log_match(match)
            self._transform_match(model, match)
        return model

    def _transform_match(self, model: ModelProto, match: MatchResult):
        """
        Rewire a single DequantizeLinear -> Split match so that the Split
        consumes the quantized tensor directly and every Split output is
        followed by its own DequantizeLinear node.

        :param model: the ONNX model that owns the matched nodes
        :param match: structural match whose node is the Split and whose
            first parent branch holds the DequantizeLinear node
        """
        split = match.node
        dequant = match.parents[0][0]
        dequant_output = dequant.output[0]

        # Scale, plus the zero point when the original node has one. The
        # zero point input of DequantizeLinear is optional in ONNX, so it
        # must not be indexed unconditionally.
        quant_params = list(dequant.input[1:])

        # Iterate over a snapshot of the output names since the Split
        # outputs are renamed inside the loop
        for output_id, output_name in enumerate(list(split.output)):
            new_dequant_name = f"{split.name}_dequant.{output_id}"
            new_dequant_input = f"{new_dequant_name}_input"

            # The new DequantizeLinear node takes over the original Split
            # output name so downstream consumers need no rewiring
            new_dequant = onnx.helper.make_node(
                "DequantizeLinear",
                [new_dequant_input, *quant_params],
                [output_name],
                new_dequant_name,
            )
            # Carry over attributes (e.g. ``axis``) of the original node so
            # the replacement nodes dequantize the same way
            new_dequant.attribute.extend(dequant.attribute)
            model.graph.node.append(new_dequant)

            # The Split node now feeds the new DequantizeLinear node
            split.output[output_id] = new_dequant_input

        # The Split node now reads what used to be the input of the
        # original DequantizeLinear node
        split.input[0] = dequant.input[0]

        # Only remove the original DequantizeLinear node when nothing else
        # still depends on its output; otherwise deleting it would leave a
        # dangling tensor reference in the graph
        still_consumed = any(
            dequant_output in node.input for node in model.graph.node
        )
        is_graph_output = any(
            output.name == dequant_output for output in model.graph.output
        )
        if not (still_consumed or is_graph_output):
            self.delete_node_deferred(dequant)

src/sparseml/exporters/transforms/propagate_embedding_quantization.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616

1717
import numpy
18+
import onnx.numpy_helper
1819
from onnx import ModelProto, numpy_helper
1920

2021
from sparseml.exporters.transforms.onnx_transform import OnnxTransform
@@ -79,16 +80,22 @@ def transform(self, model: ModelProto) -> ModelProto:
7980
["Concat"],
8081
],
8182
)
83+
84+
initializer_dict = {i.name: i for i in model.graph.initializer}
85+
8286
for match in matches:
8387
(gather,) = match.parents[0]
8488
dequant = match.node
85-
slice1, _, concat1 = match.children[0]
86-
slice2, _, concat2 = match.children[1]
89+
slice1, pad1, concat1 = match.children[0]
90+
slice2, pad2, concat2 = match.children[1]
8791
(concat,) = match.children[2]
8892

8993
# check for uint8 initializer
9094
indices = graph.get_init_by_name(gather.input[0])
91-
if indices is None or numpy_helper.to_array(indices).dtype != numpy.uint8:
95+
if indices is None or numpy_helper.to_array(indices).dtype not in [
96+
numpy.uint8,
97+
numpy.int8,
98+
]:
9299
continue
93100

94101
# check that all concats are the same
@@ -97,11 +104,35 @@ def transform(self, model: ModelProto) -> ModelProto:
97104

98105
self.log_match(match)
99106

100-
assert concat.input[2] == dequant.output[0]
101-
concat.input[2] = gather.output[0]
107+
for id, input_name in enumerate(concat.input):
108+
if input_name == dequant.output[0]:
109+
break
110+
111+
concat.input[id] = gather.output[0]
102112
slice1.input[0] = gather.output[0]
103113
slice2.input[0] = gather.output[0]
104114

115+
zero_point_initializer = initializer_dict[match.node.input[2]]
116+
zero_point = onnx.numpy_helper.to_array(zero_point_initializer)
117+
118+
pad1_value_initializer = initializer_dict[pad1.input[2]]
119+
pad1_value = onnx.numpy_helper.to_array(pad1_value_initializer)
120+
pad1_value = pad1_value.astype(zero_point.dtype) + zero_point
121+
new_pad1_value_initializer = numpy_helper.from_array(
122+
pad1_value, name=pad1_value_initializer.name
123+
)
124+
model.graph.initializer.remove(pad1_value_initializer)
125+
model.graph.initializer.append(new_pad1_value_initializer)
126+
127+
pad2_value_initializer = initializer_dict[pad2.input[2]]
128+
pad2_value = onnx.numpy_helper.to_array(pad2_value_initializer)
129+
pad2_value = pad2_value.astype(zero_point.dtype) + zero_point
130+
new_pad2_value_initializer = numpy_helper.from_array(
131+
pad2_value, name=pad2_value_initializer.name
132+
)
133+
model.graph.initializer.remove(pad2_value_initializer)
134+
model.graph.initializer.append(new_pad2_value_initializer)
135+
105136
tmp = concat.output[0]
106137
concat.output[0] = dequant.output[0]
107138
dequant.output[0] = tmp

tests/sparseml/exporters/transforms/test_propagate_embedding_quantization.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,18 @@ def onnx_model():
3232
"output", onnx.TensorProto.FLOAT, (1,)
3333
)
3434
scale = onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, (1,), [1.0])
35+
zero_point = onnx.helper.make_tensor(
36+
"zero_point", onnx.TensorProto.UINT8, (1,), [128]
37+
)
3538
starts = onnx.helper.make_tensor("starts", onnx.TensorProto.INT64, (1,), [0])
3639
ends = onnx.helper.make_tensor("ends", onnx.TensorProto.INT64, (1,), [1])
3740
pads = onnx.helper.make_tensor("pads", onnx.TensorProto.INT64, (1,), [1])
41+
padding1_value = onnx.helper.make_tensor(
42+
"padding1_value", onnx.TensorProto.FLOAT, (1,), [0.0]
43+
)
44+
padding2_value = onnx.helper.make_tensor(
45+
"padding2_value", onnx.TensorProto.FLOAT, (1,), [0.0]
46+
)
3847
embeddings = onnx.helper.make_tensor(
3948
"embeddings", onnx.TensorProto.UINT8, (1,), [0]
4049
)
@@ -43,7 +52,7 @@ def onnx_model():
4352
)
4453
dequant = onnx.helper.make_node(
4554
"DequantizeLinear",
46-
["gather_output", "scale"],
55+
["gather_output", "scale", "zero_point"],
4756
["dequant_output"],
4857
name="dequant",
4958
)
@@ -52,13 +61,13 @@ def onnx_model():
5261
"Slice", ["dequant_output", "starts", "ends"], ["slice1_output"], name="slice1"
5362
)
5463
pad1 = onnx.helper.make_node(
55-
"Pad", ["slice1_output", "pads"], ["pad1_output"], name="pad1"
64+
"Pad", ["slice1_output", "pads", "padding1_value"], ["pad1_output"], name="pad1"
5665
)
5766
slice2 = onnx.helper.make_node(
5867
"Slice", ["dequant_output", "starts", "ends"], ["slice2_output"], name="slice2"
5968
)
6069
pad2 = onnx.helper.make_node(
61-
"Pad", ["slice2_output", "pads"], ["pad2_output"], name="pad2"
70+
"Pad", ["slice2_output", "pads", "padding2_value"], ["pad2_output"], name="pad2"
6271
)
6372
concat = onnx.helper.make_node(
6473
"Concat",
@@ -73,7 +82,16 @@ def onnx_model():
7382
name="g",
7483
inputs=[model_input],
7584
outputs=[model_output],
76-
initializer=[scale, starts, ends, embeddings, pads],
85+
initializer=[
86+
scale,
87+
zero_point,
88+
starts,
89+
ends,
90+
embeddings,
91+
pads,
92+
padding1_value,
93+
padding2_value,
94+
],
7795
)
7896

7997
model = onnx.helper.make_model(graph)

0 commit comments

Comments
 (0)