Skip to content

Commit 07f19af

Browse files
Merge branch 'pytorch:main' into pr_model_improve
2 parents 3fe89b1 + a073668 commit 07f19af

20 files changed

+78
-63
lines changed

backends/arm/_passes/cast_int64_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch._export.utils import is_buffer
1313

1414
logger = logging.getLogger(__name__)
15-
logger.setLevel(logging.WARNING)
1615

1716

1817
class CastInt64BuffersToInt32Pass(ExportPass):

backends/arm/arm_backend.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,13 @@
1111
# JIT compiler flows.
1212
#
1313

14-
import logging
15-
1614
from typing import List, Optional
1715

1816
from executorch.backends.arm.tosa_specification import TosaSpecification
1917

2018
from executorch.exir.backend.compile_spec_schema import CompileSpec
2119

2220

23-
logger = logging.getLogger(__name__)
24-
logger.setLevel(logging.WARNING)
25-
26-
2721
class ArmCompileSpecBuilder:
2822
def __init__(self):
2923
self.compile_spec: List[CompileSpec] = []

backends/arm/operator_support/right_shift_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

1919
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.WARNING)
2120

2221

2322
@register_tosa_support_check

backends/arm/operator_support/slice_copy_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717

1818
logger = logging.getLogger(__name__)
19-
logger.setLevel(logging.WARNING)
2019

2120

2221
@register_tosa_support_check

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def tosa_support_factory(
112112
# Negative checks: Remove nodes from partitioning
113113
negative_checks: list[OperatorSupportBase] = [
114114
CheckInt64Inputs(exported_program, reporter),
115+
CheckFloat64Inputs(exported_program, reporter),
115116
*[
116117
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
117118
for check in (additional_checks if additional_checks else [])
@@ -443,3 +444,26 @@ def is_node_supported(
443444
)
444445
return False
445446
return True
447+
448+
449+
class CheckFloat64Inputs(OperatorSupportBase):
450+
451+
def __init__(
452+
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
453+
):
454+
self.reporter = reporter
455+
super().__init__()
456+
457+
def is_node_supported(
458+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
459+
) -> bool:
460+
461+
for input_node in node.all_input_nodes:
462+
tensor = get_first_fake_tensor(input_node)
463+
if tensor.dtype == torch.float64:
464+
self.reporter.report_reject(
465+
node,
466+
f"Had float64 input {input_node.name} that couldn't be handled.",
467+
)
468+
return False
469+
return True

backends/arm/operators/op_maximum.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,27 @@ def define_node(
3636
inputs: List[TosaArg],
3737
output: TosaArg,
3838
) -> None:
39-
assert inputs[0].dtype == inputs[1].dtype
39+
if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype:
40+
raise TypeError(
41+
f"Data type of inputs and output must be the same. Got input 0 dtype: "
42+
f"{inputs[0].dtype}, input 1 dtype: {inputs[1].dtype} and output "
43+
f"dtype: {output.dtype}"
44+
)
4045

4146
scale_back = 1.0
4247
max_output = output
4348
if inputs[0].dtype == ts.DType.INT8:
4449
input_qparams = get_input_qparams(node)
45-
assert (
46-
len(input_qparams) == 2
47-
), f"Both inputs needs to have quantization information for {node}"
48-
# insert RESCALEs to int32
49-
assert (
50-
input_qparams[0] == input_qparams[1]
51-
), "Both inputs must have same quantization for MAX"
50+
if len(input_qparams) != 2:
51+
raise ValueError(
52+
f"Both inputs need to have quantization information for {node}"
53+
)
54+
if input_qparams[0] != input_qparams[1]:
55+
raise ValueError(
56+
"Both inputs must have the same quantization parameters for MAX"
57+
)
5258

59+
# insert RESCALEs to int32
5360
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5461
tosa_graph, inputs, node
5562
)

backends/arm/operators/op_reciprocal.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,16 @@ def define_node(
3434
inputs: List[TosaArg],
3535
output: TosaArg,
3636
) -> None:
37-
assert inputs[0].dtype == output.dtype == ts.DType.FP32
37+
if len(node.all_input_nodes) != 1:
38+
raise ValueError(
39+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40+
)
41+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42+
raise ValueError(
43+
f"Input and output for {self.target} need to be FP32, got "
44+
f"{inputs[0].dtype=} and {output.dtype=}"
45+
)
46+
3847
tosa_graph.addOperator(
3948
ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
4049
)

backends/arm/operators/op_sub.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,19 @@ def define_node(
4040
) -> None:
4141
# Specification (0.80) states that input and output types
4242
# should all be the same
43-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
43+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
44+
raise TypeError(
45+
f"All IO needs to have the same data type, got input 1: "
46+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
47+
f"{output.dtype}"
48+
)
49+
4450
# Handle int8 (quantized) and int32
45-
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
51+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
52+
if inputs[0].dtype not in supported_dtypes:
53+
raise TypeError(
54+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
55+
)
4656

4757
if inputs[0].dtype == ts.DType.INT8:
4858
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
@@ -97,15 +107,27 @@ def define_node(
97107
) -> None:
98108
# Specification (0.80) states that input and output types
99109
# should all be the same
100-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
110+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
111+
raise TypeError(
112+
f"All IO needs to have the same data type, got input 1: "
113+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
114+
f"{output.dtype}"
115+
)
101116

102117
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
103118
# Call the inherited define_node for handling integers
104119
super().define_node(node, tosa_graph, inputs, output)
105120
else:
106121
# FP32 Sub lowering
107-
assert inputs[0].dtype == ts.DType.FP32
108-
assert output.dtype == ts.DType.FP32
122+
if (
123+
inputs[0].dtype != ts.DType.FP32
124+
or inputs[1].dtype != ts.DType.FP32
125+
or output.dtype != ts.DType.FP32
126+
):
127+
raise TypeError(
128+
f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, "
129+
f"input 2: {inputs[1].dtype} and output: {output.dtype}"
130+
)
109131

110132
# MI lowering
111133
tosa_graph.addOperator(

backends/arm/test/misc/test_debug_feats.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import logging
87
import os
98
import shutil
109
import tempfile
@@ -15,9 +14,6 @@
1514

1615
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1716

18-
logger = logging.getLogger(__name__)
19-
logger.setLevel(logging.INFO)
20-
2117

2218
class Linear(torch.nn.Module):
2319
def __init__(
@@ -205,7 +201,6 @@ def test_collate_tosa_BI_tests(self):
205201

206202

207203
def test_dump_tosa_ops(caplog):
208-
caplog.set_level(logging.INFO)
209204
model = Linear(20, 30)
210205
(
211206
ArmTester(
@@ -222,7 +217,6 @@ def test_dump_tosa_ops(caplog):
222217

223218

224219
def test_fail_dump_tosa_ops(caplog):
225-
caplog.set_level(logging.INFO)
226220

227221
class Add(torch.nn.Module):
228222
def forward(self, x):

backends/arm/test/models/test_conformer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import logging
76
import unittest
87

98
import torch
@@ -14,10 +13,6 @@
1413
from torchaudio.models import Conformer
1514

1615

17-
logger = logging.getLogger(__name__)
18-
logger.setLevel(logging.INFO)
19-
20-
2116
def get_test_inputs(dim, lengths, num_examples):
2217
return (torch.rand(num_examples, int(lengths.max()), dim), lengths)
2318

backends/arm/test/models/test_llama.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
sys.path.append(project_dir)
2929

3030
logger = logging.getLogger(__name__)
31-
logger.setLevel(logging.INFO)
3231

3332

3433
class TestLlama(unittest.TestCase):

backends/arm/test/models/test_w2l_arm.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109
from typing import Tuple
1110

@@ -19,10 +18,6 @@
1918
from torchaudio import models
2019

2120

22-
logger = logging.getLogger(__name__)
23-
logger.setLevel(logging.INFO)
24-
25-
2621
def get_test_inputs(batch_size, num_features, input_frames):
2722
return (torch.randn(batch_size, num_features, input_frames),)
2823

backends/arm/test/ops/test_batch_norm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109

1110
from typing import Tuple
@@ -15,8 +14,6 @@
1514
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1615
from parameterized import parameterized
1716

18-
logger = logging.getLogger(__name__)
19-
logger.setLevel(logging.INFO)
2017

2118
test_data_suite = [
2219
# (test_name, test_data, [num_features, affine, track_running_stats, weight, bias, running_mean, running_var,] )

backends/arm/test/ops/test_conv_combos.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import logging
87
import unittest
98

109
from typing import Tuple
@@ -18,8 +17,6 @@
1817
from parameterized import parameterized
1918
from torch.nn.parameter import Parameter
2019

21-
logger = logging.getLogger(__name__)
22-
logger.setLevel(logging.INFO)
2320

2421
"""
2522
This file contain unit tests where conv are combined with other ops.

backends/arm/test/ops/test_div.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109

1110
from typing import Optional, Tuple, Union
@@ -17,8 +16,6 @@
1716
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1817
from parameterized import parameterized
1918

20-
logger = logging.getLogger(__name__)
21-
logger.setLevel(logging.INFO)
2219

2320
test_data_suite = [
2421
# (test_name, input, other, rounding_mode) See torch.div() for info

backends/arm/test/ops/test_linear.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109

1110
from typing import Tuple
@@ -19,9 +18,6 @@
1918
from executorch.exir.backend.compile_spec_schema import CompileSpec
2019
from parameterized import parameterized
2120

22-
logger = logging.getLogger(__name__)
23-
logger.setLevel(logging.INFO)
24-
2521

2622
test_data_suite_rank1 = [
2723
# (test_name, test_data, out_features, has_bias)

backends/arm/test/ops/test_max_pool.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109

1110
from typing import Tuple
@@ -26,8 +25,6 @@
2625
from executorch.exir.backend.backend_details import CompileSpec
2726
from parameterized import parameterized
2827

29-
logger = logging.getLogger(__name__)
30-
logger.setLevel(logging.INFO)
3128

3229
test_data_suite = [
3330
# (test_name, test_data, [kernel_size, stride, padding])

backends/arm/test/ops/test_sigmoid.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109

1110
from typing import Tuple
@@ -16,9 +15,6 @@
1615
from executorch.exir.backend.compile_spec_schema import CompileSpec
1716
from parameterized import parameterized
1817

19-
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.INFO)
21-
2218

2319
test_data_suite = [
2420
# (test_name, test_data)

backends/arm/test/runner_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import numpy as np
1919
import torch
2020

21-
logger = logging.getLogger(__name__)
2221
try:
2322
import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model
2423
except ImportError:
@@ -37,7 +36,6 @@
3736
from tosa_tools.v0_80.tosa import TosaGraph
3837

3938
logger = logging.getLogger(__name__)
40-
logger.setLevel(logging.CRITICAL)
4139

4240
# Copied from PyTorch.
4341
# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict

backends/arm/util/arm_model_evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
# Logger for outputting progress for longer running evaluation
2626
logger = logging.getLogger(__name__)
27+
# Explicitly set logging level: MLETORCH-893
2728
logger.setLevel(logging.INFO)
2829

2930

0 commit comments

Comments
 (0)