
Commit 41cfb0a

NXP backend: Add NeutronQuantizer
1 parent 4717459 commit 41cfb0a

File tree

5 files changed: +1211 −0 lines


Diff for: backends/nxp/quantizer/neutron_quantizer.py (+206 lines)

@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch

from executorch.backends.nxp.quantizer.patterns import (
    AddmmPattern,
    AvgPoolPattern,
    Conv1dPattern,
    Conv2dPattern,
    LinearPattern,
    MaxPoolPattern,
    PadPattern,
    PermutePattern,
    QuantizationPattern,
    ReluInPlacePattern,
    ReluPattern,
    ReshapePattern,
    SoftMaxPattern,
)
from executorch.backends.nxp.quantizer.utils import (
    find_sequential_partitions_aten,
    is_annotated,
    no_outside_users,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer


class NeutronAtenQuantizer(Quantizer):
    """Quantizer that annotates a single fused operator pattern with the
    given quantization config."""

    def __init__(
        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )

        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation

        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors or anchors.empty:
                continue
            if is_annotated(
                [
                    x[0]
                    for x in anchors.inputs
                    + anchors.weights
                    + anchors.biases
                    + anchors.output
                ]
            ):
                continue

            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

            def annotate_inputs(
                inputs: Union[
                    List[Tuple[fx.Node, int]],
                    List[Tuple[fx.Node, int, DerivedQuantizationSpec]],
                ],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    arg = (
                        # pyre-ignore[16]: no attribute
                        node.args[idx]
                        if isinstance(idx, int)
                        # pyre-ignore[16]: no attribute
                        else node.args[idx[0]][idx[1]]
                    )
                    annotation.input_qspec_map[arg] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            def annotate_weights_or_biases(
                weights_or_biases: List[Tuple[fx.Node, int]],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in weights_or_biases:
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    annotation.input_qspec_map[node.args[idx]] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    node.meta["quantization_annotation"] = annotation

            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_weights_or_biases(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_weights_or_biases(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


# Quantization specification used by the Neutron NPU
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
    ch_axis=0,
)

wgt_fc_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# The bias spec is set by the *PatternQuantizer directly.
bias_qspec = None


class NeutronQuantizer(ComposableQuantizer):
    def __init__(self):
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
        super().__init__(
            [
                NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
                NeutronAtenQuantizer(Conv2dPattern(), static_qconfig),
                NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
                NeutronAtenQuantizer(PadPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
            ]
        )

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        return model
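
For context, here is a minimal sketch of how a quantizer like this is typically driven through the PT2E flow. It is not part of the commit: the toy model, the example input, and the use of torch.export.export_for_training with the torch.ao.quantization.quantize_pt2e entry points are assumptions based on the standard PyTorch 2 export quantization workflow.

# Minimal usage sketch (not from this commit). Assumes a PyTorch version
# that provides torch.export.export_for_training and the PT2E quantization
# entry points; the model and inputs are placeholders.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Export to an ATen-level graph, annotate it with the NeutronQuantizer,
# run a calibration pass, then convert to a statically quantized graph.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, NeutronQuantizer())
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)

Because NeutronQuantizer is a ComposableQuantizer, prepare_pt2e applies each per-pattern NeutronAtenQuantizer in turn, so all supported operator patterns are annotated in a single pass.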
