Commit 2d4bfd7

Update on "Xnnpack test for program-data separation"
Add xnnpack test for program-data separation.

Differential Revision: [D73794695](https://our.internmc.facebook.com/intern/diff/D73794695/)

[ghstack-poisoned]
2 parents 59e5222 + 8b4c0ed commit 2d4bfd7

12 files changed: +246 −114 lines

backends/arm/_passes/__init__.py (+1)

@@ -19,6 +19,7 @@
 from .convert_squeezes_to_view import ConvertSqueezesToViewPass  # noqa
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
+from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa

backends/arm/_passes/arm_pass_manager.py (+2)

@@ -24,6 +24,7 @@
     ConvertSqueezesToViewPass,
     ConvertToClampPass,
     DecomposeBatchNormPass,
+    DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeGeluPass,
     DecomposeLayerNormPass,
@@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeNotEqualPass())
+        self.add_pass(DecomposeCosineSimilarityPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeSqrtPass())
backends/arm/_passes/decompose_cosine_similarity_pass.py (new file, +75)

@@ -0,0 +1,75 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass
+
+torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
+
+
+class DecomposeCosineSimilarityPass(ExportPass):
+    """
+    Decomposition of aten.cosine_similarity:
+
+        dot   = sum(mul(x1, x2), dims, keepdim=False)
+        norm  = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
+        eps   = full( (), eps_scalar )
+        n1c   = max(norm1, eps)
+        n2c   = max(norm2, eps)
+        denom = mul(n1c, n2c)
+        out   = div(dot, denom)
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in torch_cosine_similarity:
+            return super().call_operator(op, args, kwargs, meta)
+
+        x1, x2 = args[0], args[1]
+        dim = kwargs.get("dim", 1)
+        eps = kwargs.get("eps", 1e-8)
+        dims = [dim] if isinstance(dim, int) else list(dim)
+
+        # 1) dot
+        prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
+        dot = super().call_operator(
+            torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
+        )
+
+        # 2a) norm1 = pow(sum(x1*x1), 0.5)
+        x1_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x1), {}, meta)
+        s1 = super().call_operator(
+            torch.ops.aten.sum.dim_IntList, (x1_sq, dims, False), {}, meta
+        )
+        norm1 = super().call_operator(
+            torch.ops.aten.pow.Tensor_Scalar, (s1, 0.5), {}, meta
+        )
+
+        # 2b) norm2 = pow(sum(x2*x2), 0.5)
+        x2_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x2, x2), {}, meta)
+        s2 = super().call_operator(
+            torch.ops.aten.sum.dim_IntList, (x2_sq, dims, False), {}, meta
+        )
+        norm2 = super().call_operator(
+            torch.ops.aten.pow.Tensor_Scalar, (s2, 0.5), {}, meta
+        )
+
+        # 3) eps scalar - broadcast it ourselves, since TOSA does not do this for scalars
+        eps_t = super().call_operator(
+            torch.ops.aten.full_like.default, (norm1, eps), {}, meta
+        )
+
+        # 4) clamp to avoid division by zero
+        n1c = super().call_operator(
+            torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
+        )
+        n2c = super().call_operator(
+            torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
+        )
+
+        # 5) denom and divide
+        denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
+        out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)
+
+        return out
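For intuition, the decomposition above can be checked numerically against torch.cosine_similarity using plain eager tensor ops. The snippet below is an illustrative sketch, not part of this commit; it mirrors the steps the pass emits, including the full_like/maximum eps clamp.

import torch

def decomposed_cosine_similarity(x1, x2, dim=1, eps=1e-8):
    # Same steps as DecomposeCosineSimilarityPass, in eager mode.
    dot = torch.sum(x1 * x2, dim=dim, keepdim=False)
    norm1 = torch.pow(torch.sum(x1 * x1, dim=dim, keepdim=False), 0.5)
    norm2 = torch.pow(torch.sum(x2 * x2, dim=dim, keepdim=False), 0.5)
    eps_t = torch.full_like(norm1, eps)  # broadcast eps explicitly, as the pass does for TOSA
    denom = torch.maximum(norm1, eps_t) * torch.maximum(norm2, eps_t)
    return dot / denom

x1, x2 = torch.rand(2, 3, 4), torch.rand(2, 3, 4)
ref = torch.cosine_similarity(x1, x2, dim=1, eps=1e-8)
assert torch.allclose(decomposed_cosine_similarity(x1, x2, dim=1), ref, atol=1e-6)

# With all-zero inputs the eps clamp keeps the denominator positive,
# so the decomposition returns 0 instead of NaN.
z = torch.zeros(2, 3, 4)
assert torch.equal(decomposed_cosine_similarity(z, z, dim=1), torch.zeros(2, 4))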

backends/arm/test/models/test_nn_functional.py (−1)

@@ -106,7 +106,6 @@ def test_nn_functional_MI(test_data):

 x_fails = {
     "normalize": "MLETORCH-852: Support aten.index_put.default",
-    "cosine_similarity": "MLETORCH-854: Support aten.linalg_vector_norm.default",
     "unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
     "fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
 }
(new test file, +52)

@@ -0,0 +1,52 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm._passes.decompose_cosine_similarity_pass import (
+    DecomposeCosineSimilarityPass,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+
+input_t = Tuple[torch.Tensor, torch.Tensor]
+
+
+class CosineSimilarityModel(torch.nn.Module):
+    def get_inputs(self) -> input_t:
+        return (torch.rand(2, 3, 4), torch.rand(2, 3, 4))
+
+    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        return torch.cosine_similarity(x1, x2, dim=1, eps=1e-6)
+
+
+modules = {"cosine_basic": CosineSimilarityModel()}
+
+
+@common.parametrize("module", modules)
+def test_decompose_cosine_similarity_tosa_BI(module):
+
+    ops_after_pass = {
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 5,
+        "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 3,
+        "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2,
+        "executorch_exir_dialects_edge__ops_aten_full_like_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_maximum_default": 2,
+        "executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1,
+    }
+
+    pipeline = PassPipeline[input_t](
+        module,
+        module.get_inputs(),
+        tosa_version="TOSA-0.80+BI",
+        ops_before_pass=None,
+        ops_not_before_pass=None,
+        ops_after_pass=ops_after_pass,
+        ops_not_after_pass=None,
+        pass_list=[DecomposeCosineSimilarityPass],
+    )
+    pipeline.run()
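The ops_after_pass counts line up with the decomposition: the pass itself emits four aten.mul ops (x1*x2, x1*x1, x2*x2, and the denominator), three sums, two pow(_, 0.5), one full_like, two maximum, and one div. The extra mul and the reciprocal in the expected counts suggest that aten.div is lowered to reciprocal followed by mul later in the edge/TOSA pipeline; that lowering is inferred from these counts, not shown in this diff. As a rough tally:

# Hypothetical tally of edge-dialect ops for one cosine_similarity call,
# assuming aten.div is lowered to reciprocal + mul downstream (an inference
# from the expected counts above, not something this diff shows directly).
expected_ops = {
    "aten_mul_Tensor": 4 + 1,      # x1*x2, x1*x1, x2*x2, n1c*n2c, plus one from the div lowering
    "aten_sum_dim_IntList": 3,     # dot, sum(x1*x1), sum(x2*x2)
    "aten_pow_Tensor_Scalar": 2,   # the two square roots via pow(_, 0.5)
    "aten_full_like_default": 1,   # broadcast eps
    "aten_maximum_default": 2,     # clamp both norms
    "aten_reciprocal_default": 1,  # from the div lowering
}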

backends/xnnpack/test/runtime/test_xnn_data_separation.cpp (+104 −104; the removed and re-added lines are textually identical in this view, so the resulting file is shown once)

@@ -6,109 +6,109 @@
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::FlatTensorDataMap;
using executorch::runtime::DataLoader;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Method;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::testing::ManagedMemoryManager;
using torch::executor::util::FileDataLoader;

constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;

class DataSeparationTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Since these tests cause ET_LOG to be called, the PAL must be initialized
    // first.
    executorch::runtime::runtime_init();

    // Create data loaders.
    Result<FileDataLoader> linear_program_loader =
        FileDataLoader::from(std::getenv("ET_MODULE_LINEAR_XNN_PROGRAM_PATH"));
    ASSERT_EQ(linear_program_loader.error(), Error::Ok);
    linear_program_loader_ = std::make_unique<FileDataLoader>(
        std::move(linear_program_loader.get()));

    Result<FileDataLoader> linear_data_loader =
        FileDataLoader::from(std::getenv("ET_MODULE_LINEAR_XNN_DATA_PATH"));
    ASSERT_EQ(linear_data_loader.error(), Error::Ok);
    linear_data_loader_ =
        std::make_unique<FileDataLoader>(std::move(linear_data_loader.get()));

    // Create programs.
    Result<Program> linear_program = Program::load(
        linear_program_loader_.get(),
        Program::Verification::InternalConsistency);
    ASSERT_EQ(linear_program.error(), Error::Ok);
    linear_program_ =
        std::make_unique<Program>(std::move(linear_program.get()));

    Result<FlatTensorDataMap> linear_data_map =
        FlatTensorDataMap::load(linear_data_loader_.get());
    EXPECT_EQ(linear_data_map.error(), Error::Ok);
    linear_data_map_ =
        std::make_unique<FlatTensorDataMap>(std::move(linear_data_map.get()));
  }

 private:
  std::unique_ptr<FileDataLoader> linear_program_loader_;
  std::unique_ptr<FileDataLoader> linear_data_loader_;

 protected:
  std::unique_ptr<Program> linear_program_;
  std::unique_ptr<FlatTensorDataMap> linear_data_map_;
};

TEST_F(DataSeparationTest, TestExternalData) {
  FlatTensorDataMap* data_map = linear_data_map_.get();
  EXPECT_EQ(data_map->get_num_keys().get(), 2);

  Result<const char*> key0 = data_map->get_key(0);
  EXPECT_EQ(key0.error(), Error::Ok);
  Result<const char*> key1 = data_map->get_key(1);
  EXPECT_EQ(key1.error(), Error::Ok);

  // Check that accessing keys out of bounds fails.
  EXPECT_EQ(data_map->get_key(2).error(), Error::InvalidArgument);

  // Linear.weight
  Result<FreeableBuffer> data0 = data_map->get_data(key0.get());
  EXPECT_EQ(data0.error(), Error::Ok);
  EXPECT_EQ(data0.get().size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float)

  // Linear.bias
  Result<FreeableBuffer> data1 = data_map->get_data(key1.get());
  EXPECT_EQ(data1.error(), Error::Ok);
  EXPECT_EQ(data1.get().size(), 12); // 3*4 (3 vector, 4 bytes per float)

  // Check that accessing non-existent data fails.
  Result<FreeableBuffer> data2 = data_map->get_data("nonexistent");
  EXPECT_EQ(data2.error(), Error::NotFound);
}

TEST_F(DataSeparationTest, TestE2E) {
  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  Result<Method> method = linear_program_->load_method(
      "forward", &mmm.get(), nullptr, linear_data_map_.get());
  ASSERT_EQ(method.error(), Error::Ok);

  // Can execute the method.
  Error err = method->execute();
  ASSERT_EQ(err, Error::Ok);
}

examples/models/llama/runner/CMakeLists.txt (+1 −1)

@@ -53,7 +53,7 @@ else()
 endif()

 set(llama_runner_deps executorch_core extension_data_loader extension_module
-    extension_tensor
+    extension_tensor extension_flat_tensor
 )

 target_link_libraries(llama_runner PUBLIC ${llama_runner_deps})

examples/models/llava/runner/CMakeLists.txt (+1 −1)

@@ -41,7 +41,7 @@ add_subdirectory(
 add_library(llava_runner STATIC ${_llava_runner__srcs})

 set(llava_runner_deps executorch_core extension_data_loader extension_llm_runner
-    extension_module extension_tensor
+    extension_module extension_tensor extension_flat_tensor
 )

 target_link_libraries(llava_runner PUBLIC ${llava_runner_deps})
