|
3 | 3 | # This source code is licensed under the BSD-style license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
| 6 | +import logging |
6 | 7 | import operator
|
7 | 8 | from dataclasses import dataclass
|
8 | 9 | from typing import Callable, List, Optional
|
|
11 | 12 | import torch.fx
|
12 | 13 | from executorch.backends.arm.quantizer import arm_quantizer_utils
|
13 | 14 | from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
|
| 15 | +from executorch.backends.arm.tosa_utils import get_node_debug_info |
14 | 16 | from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
|
15 | 17 | from torch.ao.quantization.quantizer.utils import (
|
16 | 18 | _annotate_input_qspec_map,
|
17 | 19 | _annotate_output_qspec,
|
18 | 20 | )
|
19 | 21 | from torch.fx import Node
|
20 | 22 |
|
| 23 | +logger = logging.getLogger(__name__) |
| 24 | + |
21 | 25 |
|
22 | 26 | @dataclass(frozen=True)
|
23 | 27 | class _QuantProperty:
|
@@ -45,19 +49,52 @@ def _as_list(x):
|
45 | 49 |
|
46 | 50 |
|
47 | 51 | def _is_ok_for_quantization(
|
48 |
| - node: Node, quant_property: _QuantProperty, gm: torch.fx.GraphModule |
| 52 | + node: Node, quant_properties: _OpQuantProperties, gm: torch.fx.GraphModule |
49 | 53 | ) -> bool:
|
50 |
| - if quant_property.optional and ( |
51 |
| - quant_property.index >= len(node.args) |
52 |
| - or node.args[quant_property.index] is None |
53 |
| - ): |
54 |
| - return True |
| 54 | + """Check if a node can be quantized. |
| 55 | +
|
| 56 | + A node can be quantized if: |
| 57 | + - All inputs that are required for quantization are of type `float32` |
| 58 | + and are not large scalar values. |
| 59 | + - The output of the node itself is of type `float32` and is not a large scalar. |
| 60 | +
|
| 61 | + Args: |
| 62 | + node (Node): The node being analyzed. |
| 63 | + quant_properties (_OpQuantProperties): Contains quantization properties for |
| 64 | + the node, including input and output quantization specifications. |
| 65 | + gm (torch.fx.GraphModule): The graph module containing the computational graph. |
| 66 | +
|
| 67 | + Returns: |
| 68 | + bool: `True` if the node can be quantized, otherwise `False`. |
| 69 | + """ |
| 70 | + # Check output |
| 71 | + if quant_properties.quant_output is not None: |
| 72 | + if not arm_quantizer_utils.is_ok_for_quantization(node, gm): # type: ignore[attr-defined] |
| 73 | + logger.debug( |
| 74 | + f"Could not quantize node due to output: " |
| 75 | + f"{get_node_debug_info(node, gm)}" |
| 76 | + ) |
55 | 77 |
|
56 |
| - for n_arg in _as_list(node.args[quant_property.index]): |
57 |
| - assert isinstance(n_arg, Node) |
58 |
| - if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined] |
59 | 78 | return False
|
60 | 79 |
|
| 80 | + # Check inputs |
| 81 | + for quant_property in quant_properties.quant_inputs: |
| 82 | + if quant_property.optional and ( |
| 83 | + quant_property.index >= len(node.args) |
| 84 | + or node.args[quant_property.index] is None |
| 85 | + ): |
| 86 | + continue |
| 87 | + |
| 88 | + for n_arg in _as_list(node.args[quant_property.index]): |
| 89 | + assert isinstance(n_arg, Node) |
| 90 | + if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined] |
| 91 | + logger.debug( |
| 92 | + f'could not quantize node due to input "{node}": ' |
| 93 | + f"{get_node_debug_info(node, gm)}" |
| 94 | + ) |
| 95 | + |
| 96 | + return False |
| 97 | + |
61 | 98 | return True
|
62 | 99 |
|
63 | 100 |
|
@@ -355,14 +392,9 @@ def any_or_hardtanh_min_zero(n: Node):
|
355 | 392 | return quant_properties
|
356 | 393 |
|
357 | 394 | # Check that each inputs/outputs can be quantized properly with the
|
358 |
| - # provided QuantProperties |
359 |
| - for quant_property in quant_properties.quant_inputs: |
360 |
| - if not _is_ok_for_quantization(node, quant_property, gm): |
361 |
| - return None # type: ignore[return-value] |
362 |
| - |
363 |
| - if quant_properties.quant_output is not None: |
364 |
| - if not _is_ok_for_quantization(node, quant_properties.quant_output, gm): |
365 |
| - return None # type: ignore[return-value] |
| 395 | + # provided quantization properties. |
| 396 | + if not _is_ok_for_quantization(node, quant_properties, gm): |
| 397 | + return None # type: ignore[return-value] |
366 | 398 |
|
367 | 399 | return quant_properties
|
368 | 400 |
|
|
0 commit comments