Commit bc3d437

Arm backend: Change _is_ok_for_quantization to support output check (#9795)
`_is_ok_for_quantization` now checks the node itself, in addition to its inputs, to verify that the node can be quantized; previously only the node's inputs were checked. That gap caused TestSplit to fail, which is fixed by the change to `is_non_float_tensor` in `arm_quantizer_utils`: the function now handles the case where node.meta["val"] is a `list` of `FakeTensor`. It traverses the list and returns `True` if any element is **not** a `FakeTensor` or does **not** have `torch.float32` as its data type.

Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 77c35f5 commit bc3d437
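
For context (not part of the commit): after fake-tensor tracing, a multi-output op such as split stores a sequence of `FakeTensor`s in node.meta["val"], one entry per output, which is exactly the case the old single-tensor check missed. A minimal sketch, assuming a recent PyTorch where `make_fx` is available:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    # split produces multiple outputs, so its FX node records a
    # sequence of FakeTensors in node.meta["val"], one per chunk
    return torch.split(x, 2)


gm = make_fx(f, tracing_mode="fake")(torch.randn(6, 4))
for node in gm.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, (list, tuple)):
        # e.g. aten.split.Tensor ['FakeTensor', 'FakeTensor', 'FakeTensor']
        print(node.target, [type(t).__name__ for t in val])
```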

File tree

backends/arm/quantizer/arm_quantizer_utils.py
backends/arm/quantizer/quantization_annotator.py

2 files changed: +75 −20 lines

backends/arm/quantizer/arm_quantizer_utils.py (+26 −3)

@@ -11,7 +11,7 @@
 # Utility functions for TOSAQuantizer
 #
 
-from typing import cast
+from typing import cast, Sequence
 
 import torch
 from torch._subclasses import FakeTensor
@@ -76,9 +76,32 @@ def is_large_scalar(node: Node, gm: GraphModule):
 
 
 def is_non_float_tensor(node: Node) -> bool:
-    """Check if the input is not a float tensor, so that we can skip quantization for the node
-    since observers only works with float Tensors
+    """Check if the output of a node has a data type other than `torch.float32`.
+
+    If the output is not `torch.float32`, quantization cannot be performed, as
+    observers only work with floating-point tensors.
+
+    Args:
+        node (Node): The node to check the output(s) for.
+
+    Returns:
+        bool: `True` if the data type is not float32, otherwise `False`.
+
+    Note:
+        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
+          element is **not** an instance of `FakeTensor` or does **not** have
+          `torch.float32` as its data type.
+        - If `node.meta["val"]` is missing or is not an instance of `FakeTensor`,
+          the function returns `True`.
     """
+    if "val" in node.meta and isinstance(node.meta["val"], Sequence):
+        return any(
+            not isinstance(fake_tensor, FakeTensor)
+            or fake_tensor.dtype != torch.float32
+            for fake_tensor in node.meta["val"]
+        )
+
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return True
+
     return node.meta["val"].dtype != torch.float32
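
To illustrate the new behavior in isolation, here is a hypothetical standalone mirror of the decision logic above (`has_non_float_output` is an illustrative name, not part of the commit), applied directly to a node.meta["val"]-style value:

```python
from typing import Sequence

import torch
from torch._subclasses import FakeTensor
from torch._subclasses.fake_tensor import FakeTensorMode


def has_non_float_output(val) -> bool:
    # Hypothetical mirror of is_non_float_tensor, taking the meta value directly.
    if isinstance(val, Sequence):
        # Multi-output node: any element that is not a float32 FakeTensor
        # blocks quantization
        return any(
            not isinstance(t, FakeTensor) or t.dtype != torch.float32
            for t in val
        )
    if not isinstance(val, FakeTensor):
        return True  # missing or non-tensor metadata: cannot observe
    return val.dtype != torch.float32


with FakeTensorMode():
    ok = [torch.empty(2, 4), torch.empty(2, 4)]  # all float32
    mixed = [torch.empty(2, 4), torch.empty(2, 4, dtype=torch.int64)]

print(has_non_float_output(ok))     # False -> quantization can proceed
print(has_non_float_output(mixed))  # True  -> int64 element blocks it
```

Note that `isinstance(val, Sequence)` also matches tuples, so tuple-valued meta entries take the same path as lists.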

backends/arm/quantizer/quantization_annotator.py (+49 −17)

@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import operator
 from dataclasses import dataclass
 from typing import Callable, List, Optional
@@ -11,13 +12,16 @@
 import torch.fx
 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from executorch.backends.arm.tosa_utils import get_node_debug_info
 from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
 from torch.ao.quantization.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
 )
 from torch.fx import Node
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass(frozen=True)
 class _QuantProperty:
@@ -45,19 +49,52 @@ def _as_list(x):
 
 
 def _is_ok_for_quantization(
-    node: Node, quant_property: _QuantProperty, gm: torch.fx.GraphModule
+    node: Node, quant_properties: _OpQuantProperties, gm: torch.fx.GraphModule
 ) -> bool:
-    if quant_property.optional and (
-        quant_property.index >= len(node.args)
-        or node.args[quant_property.index] is None
-    ):
-        return True
+    """Check if a node can be quantized.
+
+    A node can be quantized if:
+        - All inputs that are required for quantization are of type `float32`
+          and are not large scalar values.
+        - The output of the node itself is of type `float32` and is not a large scalar.
+
+    Args:
+        node (Node): The node being analyzed.
+        quant_properties (_OpQuantProperties): Contains quantization properties for
+            the node, including input and output quantization specifications.
+        gm (torch.fx.GraphModule): The graph module containing the computational graph.
+
+    Returns:
+        bool: `True` if the node can be quantized, otherwise `False`.
+    """
+    # Check output
+    if quant_properties.quant_output is not None:
+        if not arm_quantizer_utils.is_ok_for_quantization(node, gm):  # type: ignore[attr-defined]
+            logger.debug(
+                f"Could not quantize node due to output: "
+                f"{get_node_debug_info(node, gm)}"
+            )
 
-    for n_arg in _as_list(node.args[quant_property.index]):
-        assert isinstance(n_arg, Node)
-        if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm):  # type: ignore[attr-defined]
             return False
 
+    # Check inputs
+    for quant_property in quant_properties.quant_inputs:
+        if quant_property.optional and (
+            quant_property.index >= len(node.args)
+            or node.args[quant_property.index] is None
+        ):
+            continue
+
+        for n_arg in _as_list(node.args[quant_property.index]):
+            assert isinstance(n_arg, Node)
+            if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm):  # type: ignore[attr-defined]
+                logger.debug(
+                    f'could not quantize node due to input "{node}": '
+                    f"{get_node_debug_info(node, gm)}"
+                )
+
+                return False
+
     return True
 
 
@@ -355,14 +392,9 @@ def any_or_hardtanh_min_zero(n: Node):
        return quant_properties
 
    # Check that each inputs/outputs can be quantized properly with the
-    # provided QuantProperties
-    for quant_property in quant_properties.quant_inputs:
-        if not _is_ok_for_quantization(node, quant_property, gm):
-            return None  # type: ignore[return-value]
-
-    if quant_properties.quant_output is not None:
-        if not _is_ok_for_quantization(node, quant_properties.quant_output, gm):
-            return None  # type: ignore[return-value]
+    # provided quantization properties.
+    if not _is_ok_for_quantization(node, quant_properties, gm):
+        return None  # type: ignore[return-value]
 
     return quant_properties
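
The shape of the refactor at the call site, as a simplified sketch: the class and parameter names below are stand-ins, not the ExecuTorch API (the real function also consults the `GraphModule` and emits the debug logs shown above). One call now decides for the output and all inputs together:

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass(frozen=True)
class QuantProperty:  # simplified stand-in for _QuantProperty
    index: int
    optional: bool = False


@dataclass
class OpQuantProperties:  # simplified stand-in for _OpQuantProperties
    quant_inputs: List[QuantProperty] = field(default_factory=list)
    quant_output: Optional[QuantProperty] = None


def is_ok_for_quantization(
    args: tuple,
    props: OpQuantProperties,
    arg_ok: Callable[[object], bool],
    output_ok: bool,
) -> bool:
    """Single entry point: check the output first, then every required input."""
    if props.quant_output is not None and not output_ok:
        return False  # the node's own result cannot be observed
    for p in props.quant_inputs:
        if p.optional and (p.index >= len(args) or args[p.index] is None):
            continue  # absent optional input: nothing to check
        if not arg_ok(args[p.index]):
            return False
    return True


# Usage: one decision for the whole op instead of a loop at the call site.
props = OpQuantProperties(
    quant_inputs=[QuantProperty(0), QuantProperty(1, optional=True)],
    quant_output=QuantProperty(0),
)
print(is_ok_for_quantization(("x",), props, arg_ok=lambda a: True, output_ok=True))  # True
```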
