Commit 7516dfb

NXP backend: Add NeutronQuantizer
1 parent 4717459 commit 7516dfb

File tree

5 files changed: +1209 -0 lines changed


Diff for: backends/nxp/quantizer/neutron_quantizer.py

+205
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch

from executorch.backends.nxp.quantizer.patterns import (
    AddmmPattern,
    AvgPoolPattern,
    Conv1dPattern,
    Conv2dPattern,
    LinearPattern,
    MaxPoolPattern,
    PadPattern,
    PermutePattern,
    QuantizationPattern,
    ReluInPlacePattern,
    ReluPattern,
    ReshapePattern,
    SoftMaxPattern,
)
from executorch.backends.nxp.quantizer.utils import (
    find_sequential_partitions_aten,
    is_annotated,
    no_outside_users,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer

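# Quantizer for a single fused ATen pattern: it locates matching partitions in
# the graph and annotates their inputs, weights, biases, and outputs with the
# specs from the given QuantizationConfig.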
class NeutronAtenQuantizer(Quantizer):
    def __init__(
        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )

        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation

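        # Annotate only partitions that are self-contained (no intermediate
        # results consumed outside the pattern) and not annotated already.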
        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors or anchors.empty:
                continue
            if is_annotated(
                [
                    x[0]
                    for x in anchors.inputs
                    + anchors.weights
                    + anchors.biases
                    + anchors.output
                ]
            ):
                continue

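            # Annotate the pattern's output nodes first, with the config's
            # output activation spec unless the anchor supplies a custom one.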
            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

            def annotate_inputs(
                inputs: Union[
                    List[Tuple[fx.Node, int]],
                    List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
                ],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    arg = (
                        # pyre-ignore[16]: no attribute
                        node.args[idx]
                        if isinstance(idx, int)
                        # pyre-ignore[16]: no attribute
                        else node.args[idx[0]][idx[1]]
                    )
                    annotation.input_qspec_map[arg] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            def annotate_weights_or_biases(
                weights_or_biases: List[Tuple[fx.Node, int]],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in weights_or_biases:
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    annotation.input_qspec_map[node.args[idx]] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    node.meta["quantization_annotation"] = annotation

            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_weights_or_biases(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_weights_or_biases(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


# Quantization Specification used by Neutron NPU
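# Activations use asymmetric (affine) per-tensor int8 in [-128, 127]; weights
# use symmetric per-tensor int8 clipped to [-127, 127], keeping the range
# symmetric around zero. The bias spec is left to each pattern quantizer.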
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
    ch_axis=0,
)

wgt_fc_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# Is set by the *PatternQuantizer directly.
bias_qspec = None

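# NeutronQuantizer composes one NeutronAtenQuantizer per supported pattern.
# Fully-connected patterns (Addmm, Linear) use static_fc_qconfig, which pairs
# the activations with wgt_fc_qspec; all other patterns use static_qconfig.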
class NeutronQuantizer(ComposableQuantizer):
    def __init__(self):
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
        super().__init__(
            [
                NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
                NeutronAtenQuantizer(Conv2dPattern(), static_qconfig),
                NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
                NeutronAtenQuantizer(PadPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
            ]
        )

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        return model
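
For context, a minimal sketch of how this quantizer plugs into the standard PT2E flow. The toy model, example inputs, and export call below are illustrative and not part of this commit; it assumes a PyTorch version that provides torch.export.export_for_training:

import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Hypothetical toy model and calibration input, for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Export to an ATen-level graph, let NeutronQuantizer annotate it, run a
# calibration pass to populate the observers, then convert to a quantized graph.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, NeutronQuantizer())
prepared(*example_inputs)
quantized = convert_pt2e(prepared)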
