Skip to content

Commit 64c4b33

Browse files
authored
Arm backend: Add DecomposeCosineSimilarity (#10729)
Add pass to decompose cosine_similarity op so that it is annotated for BI profile + tests + enable disabled test in test_nn_functional.py for cosine_similarity. Signed-off-by: Elena Zhelezina <[email protected]>
1 parent fab6d7a commit 64c4b33

File tree

5 files changed

+130
-1
lines changed

5 files changed

+130
-1
lines changed

backends/arm/_passes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2020
from .convert_to_clamp import ConvertToClampPass # noqa
2121
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
22+
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2223
from .decompose_div_pass import DecomposeDivPass # noqa
2324
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2425
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa

backends/arm/_passes/arm_pass_manager.py

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ConvertSqueezesToViewPass,
2525
ConvertToClampPass,
2626
DecomposeBatchNormPass,
27+
DecomposeCosineSimilarityPass,
2728
DecomposeDivPass,
2829
DecomposeGeluPass,
2930
DecomposeLayerNormPass,
@@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
205206
self.add_pass(DecomposeVarPass())
206207
self.add_pass(DecomposeMeanDimPass())
207208
self.add_pass(DecomposeNotEqualPass())
209+
self.add_pass(DecomposeCosineSimilarityPass())
208210
self.add_pass(DecomposeDivPass())
209211
self.add_pass(DecomposeLeakyReLUPass())
210212
self.add_pass(DecomposeSqrtPass())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
10+
11+
12+
class DecomposeCosineSimilarityPass(ExportPass):
13+
"""
14+
Decomposition of aten.cosine_similarity:
15+
16+
dot = sum(mul(x1, x2), dims, keepdim=False)
17+
norm = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
18+
eps = full( (), eps_scalar )
19+
n1c = max(norm1, eps)
20+
n2c = max(norm2, eps)
21+
denom = mul(n1c, n2c)
22+
out = div(dot, denom)
23+
"""
24+
25+
def call_operator(self, op, args, kwargs, meta):
26+
if op not in torch_cosine_similarity:
27+
return super().call_operator(op, args, kwargs, meta)
28+
29+
x1, x2 = args[0], args[1]
30+
dim = kwargs.get("dim", 1)
31+
eps = kwargs.get("eps", 1e-8)
32+
dims = [dim] if isinstance(dim, int) else list(dim)
33+
34+
# 1) dot
35+
prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
36+
dot = super().call_operator(
37+
torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
38+
)
39+
40+
# 2a) norm1 = pow(sum(x1*x1), 0.5)
41+
x1_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x1), {}, meta)
42+
s1 = super().call_operator(
43+
torch.ops.aten.sum.dim_IntList, (x1_sq, dims, False), {}, meta
44+
)
45+
norm1 = super().call_operator(
46+
torch.ops.aten.pow.Tensor_Scalar, (s1, 0.5), {}, meta
47+
)
48+
49+
# 2b) norm2 = pow(sum(x2*x2), 0.5)
50+
x2_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x2, x2), {}, meta)
51+
s2 = super().call_operator(
52+
torch.ops.aten.sum.dim_IntList, (x2_sq, dims, False), {}, meta
53+
)
54+
norm2 = super().call_operator(
55+
torch.ops.aten.pow.Tensor_Scalar, (s2, 0.5), {}, meta
56+
)
57+
58+
# 3) eps scalar - we need to broadcast ourselves as TOSA dont do this for scalar
59+
eps_t = super().call_operator(
60+
torch.ops.aten.full_like.default, (norm1, eps), {}, meta
61+
)
62+
63+
# 4) clamp to avoid zero division
64+
n1c = super().call_operator(
65+
torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
66+
)
67+
n2c = super().call_operator(
68+
torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
69+
)
70+
71+
# 5) denom and divide
72+
denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
73+
out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)
74+
75+
return out

backends/arm/test/models/test_nn_functional.py

-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def test_nn_functional_MI(test_data):
106106

107107
x_fails = {
108108
"normalize": "MLETORCH-852: Support aten.index_put.default",
109-
"cosine_similarity": "MLETORCH-854: Support aten.linalg_vector_norm.default",
110109
"unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
111110
"fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
112111
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes.decompose_cosine_similarity_pass import (
11+
DecomposeCosineSimilarityPass,
12+
)
13+
from executorch.backends.arm.test import common
14+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
15+
16+
input_t = Tuple[torch.Tensor, torch.Tensor]
17+
18+
19+
class CosineSimilarityModel(torch.nn.Module):
20+
def get_inputs(self) -> input_t:
21+
return (torch.rand(2, 3, 4), torch.rand(2, 3, 4))
22+
23+
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
24+
return torch.cosine_similarity(x1, x2, dim=1, eps=1e-6)
25+
26+
27+
modules = {"cosine_basic": CosineSimilarityModel()}
28+
29+
30+
@common.parametrize("module", modules)
31+
def test_decompose_cosine_similarity_tosa_BI(module):
32+
33+
ops_after_pass = {
34+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 5,
35+
"executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 3,
36+
"executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2,
37+
"executorch_exir_dialects_edge__ops_aten_full_like_default": 1,
38+
"executorch_exir_dialects_edge__ops_aten_maximum_default": 2,
39+
"executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1,
40+
}
41+
42+
pipeline = PassPipeline[input_t](
43+
module,
44+
module.get_inputs(),
45+
tosa_version="TOSA-0.80+BI",
46+
ops_before_pass=None,
47+
ops_not_before_pass=None,
48+
ops_after_pass=ops_after_pass,
49+
ops_not_after_pass=None,
50+
pass_list=[DecomposeCosineSimilarityPass],
51+
)
52+
pipeline.run()

0 commit comments

Comments
 (0)