Skip to content

Commit 9845019

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Let Module tests use Tensor extension and aten mode. (#5298)
Summary: Pull Request resolved: #5298 . Reviewed By: kirklandsign Differential Revision: D62546149 fbshipit-source-id: 4eaee5fde0a6d934bba085bdd0360543f33d51eb
1 parent c080c48 commit 9845019

File tree

3 files changed

+53
-85
lines changed

3 files changed

+53
-85
lines changed

extension/module/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ et_cxx_test(
3232
EXTRA_LIBS
3333
extension_data_loader
3434
extension_module_static
35+
extension_tensor
3536
portable_kernels
3637
portable_ops_lib
3738
)

extension/module/test/module_test.cpp

Lines changed: 25 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <gtest/gtest.h>
1515

1616
#include <executorch/extension/data_loader/file_data_loader.h>
17+
#include <executorch/extension/tensor/tensor.h>
1718

1819
using namespace ::executorch::extension;
1920
using namespace ::executorch::runtime;
@@ -121,17 +122,11 @@ TEST_F(ModuleTest, TestNonExistentMethodMeta) {
121122

122123
TEST_F(ModuleTest, TestExecute) {
123124
Module module(model_path_);
125+
auto tensor = make_tensor_ptr({1}, {1});
124126

125-
std::array<float, 1> input{1};
126-
std::array<int32_t, 1> sizes{1};
127-
exec_aten::TensorImpl tensor(
128-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
129-
130-
const auto result = module.execute(
131-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
127+
const auto result = module.execute("forward", {tensor, tensor});
132128
EXPECT_TRUE(result.ok());
133129

134-
EXPECT_TRUE(result.ok());
135130
EXPECT_TRUE(module.is_loaded());
136131
EXPECT_TRUE(module.is_method_loaded("forward"));
137132

@@ -146,13 +141,9 @@ TEST_F(ModuleTest, TestExecutePreload) {
146141
const auto error = module.load();
147142
EXPECT_EQ(error, Error::Ok);
148143

149-
std::array<float, 1> input{1};
150-
std::array<int32_t, 1> sizes{1};
151-
exec_aten::TensorImpl tensor(
152-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
144+
auto tensor = make_tensor_ptr({1}, {1});
153145

154-
const auto result = module.execute(
155-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
146+
const auto result = module.execute("forward", {tensor, tensor});
156147
EXPECT_TRUE(result.ok());
157148

158149
const auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -166,13 +157,9 @@ TEST_F(ModuleTest, TestExecutePreload_method) {
166157
const auto error = module.load_method("forward");
167158
EXPECT_EQ(error, Error::Ok);
168159

169-
std::array<float, 1> input{1};
170-
std::array<int32_t, 1> sizes{1};
171-
exec_aten::TensorImpl tensor(
172-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
160+
auto tensor = make_tensor_ptr({1}, {1});
173161

174-
const auto result = module.execute(
175-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
162+
const auto result = module.execute("forward", {tensor, tensor});
176163
EXPECT_TRUE(result.ok());
177164

178165
const auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -189,13 +176,9 @@ TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) {
189176
const auto load_method_error = module.load_method("forward");
190177
EXPECT_EQ(load_method_error, Error::Ok);
191178

192-
std::array<float, 1> input{1};
193-
std::array<int32_t, 1> sizes{1};
194-
exec_aten::TensorImpl tensor(
195-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
179+
auto tensor = make_tensor_ptr({1}, {1});
196180

197-
const auto result = module.execute(
198-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
181+
const auto result = module.execute("forward", {tensor, tensor});
199182
EXPECT_TRUE(result.ok());
200183

201184
const auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -222,40 +205,27 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) {
222205
TEST_F(ModuleTest, TestGet) {
223206
Module module(model_path_);
224207

225-
std::array<float, 1> input{1};
226-
std::array<int32_t, 1> sizes{1};
227-
exec_aten::TensorImpl tensor(
228-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
229-
230-
const auto result = module.get(
231-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
208+
auto tensor = make_tensor_ptr({1}, {1});
232209

210+
const auto result = module.get("forward", {tensor, tensor});
233211
EXPECT_TRUE(result.ok());
234212
const auto data = result->toTensor().const_data_ptr<float>();
235213
EXPECT_NEAR(data[0], 2, 1e-5);
236214
}
237215

238216
TEST_F(ModuleTest, TestForward) {
239217
auto module = std::make_unique<Module>(model_path_);
218+
auto tensor = make_tensor_ptr({21.f});
240219

241-
std::array<float, 1> input{1};
242-
std::array<int32_t, 1> sizes{1};
243-
exec_aten::TensorImpl tensor(
244-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
245-
246-
const auto result =
247-
module->forward({exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
220+
const auto result = module->forward({tensor, tensor});
248221
EXPECT_TRUE(result.ok());
249222

250223
const auto data = result->at(0).toTensor().const_data_ptr<float>();
251224

252-
EXPECT_NEAR(data[0], 2, 1e-5);
225+
EXPECT_NEAR(data[0], 42, 1e-5);
253226

254-
std::array<float, 2> input2{2, 3};
255-
exec_aten::TensorImpl tensor2(
256-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input2.data());
257-
const auto result2 = module->forward(
258-
{exec_aten::Tensor(&tensor2), exec_aten::Tensor(&tensor2)});
227+
auto tensor2 = make_tensor_ptr({1}, {2, 3});
228+
const auto result2 = module->forward({tensor2, tensor2});
259229
EXPECT_TRUE(result2.ok());
260230

261231
const auto data2 = result->at(0).toTensor().const_data_ptr<float>();
@@ -310,26 +280,20 @@ TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
310280
EXPECT_EQ(load_error, Error::Ok);
311281
EXPECT_TRUE(module1->is_loaded());
312282

313-
std::array<float, 1> input{1};
314-
std::array<int32_t, 1> sizes{1};
315-
exec_aten::TensorImpl tensor(
316-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
283+
auto tensor = make_tensor_ptr({1}, {1});
317284

318-
auto result1 = module1->execute(
319-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
285+
const auto result1 = module1->execute("forward", {tensor, tensor});
320286
EXPECT_TRUE(result1.ok());
321287

322288
auto module2 = std::make_unique<Module>(module1->program());
323289

324-
auto result2 = module2->execute(
325-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
290+
const auto result2 = module2->execute("forward", {tensor, tensor});
326291
EXPECT_TRUE(result2.ok());
327292

328293
module1 = std::make_unique<Module>("/path/to/nonexistent/file.pte");
329294
EXPECT_FALSE(module1->is_loaded());
330295

331-
auto result3 = module2->execute(
332-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
296+
const auto result3 = module2->execute("forward", {tensor, tensor});
333297
EXPECT_TRUE(result3.ok());
334298
}
335299

@@ -361,13 +325,9 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
361325

362326
EXPECT_EQ(module.program(), shared_program);
363327

364-
std::array<float, 1> input{1};
365-
std::array<int32_t, 1> sizes{1};
366-
exec_aten::TensorImpl tensor(
367-
exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data());
328+
auto tensor = make_tensor_ptr({1}, {1});
368329

369-
auto result = module.execute(
370-
"forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
330+
const auto result = module.execute("forward", {tensor, tensor});
371331
EXPECT_TRUE(result.ok());
372332

373333
auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -392,15 +352,9 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
392352
auto thread = [](std::shared_ptr<Program> program,
393353
const std::array<float, 1>& input) {
394354
Module module(program);
395-
std::array<int32_t, 1> sizes{1};
396-
exec_aten::TensorImpl tensor(
397-
exec_aten::ScalarType::Float,
398-
sizes.size(),
399-
sizes.data(),
400-
(void*)input.data());
401-
402-
const auto result = module.forward(
403-
{exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)});
355+
auto tensor = from_blob((void*)input.data(), {1});
356+
357+
const auto result = module.forward({tensor, tensor});
404358
EXPECT_TRUE(result.ok());
405359

406360
const auto data = result->at(0).toTensor().const_data_ptr<float>();

extension/module/test/targets.bzl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
load(
2+
"@fbsource//tools/build_defs:default_platform_defs.bzl",
3+
"ANDROID",
4+
"CXX",
5+
)
16
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
27

38
def define_common_targets():
@@ -7,20 +12,28 @@ def define_common_targets():
712
TARGETS and BUCK files that call this function.
813
"""
914

10-
runtime.cxx_test(
11-
name = "test",
12-
srcs = [
13-
"module_test.cpp",
14-
],
15-
deps = [
16-
"//executorch/kernels/portable:generated_lib",
17-
"//executorch/extension/data_loader:file_data_loader",
18-
"//executorch/extension/module:module",
19-
],
20-
env = {
21-
"RESOURCES_PATH": "$(location :resources)/resources",
22-
},
23-
)
15+
for aten_mode in (True, False):
16+
aten_suffix = ("_aten" if aten_mode else "")
17+
18+
runtime.cxx_test(
19+
name = "test" + aten_suffix,
20+
srcs = [
21+
"module_test.cpp",
22+
],
23+
deps = [
24+
"//executorch/kernels/portable:generated_lib" + aten_suffix,
25+
"//executorch/extension/data_loader:file_data_loader",
26+
"//executorch/extension/module:module" + aten_suffix,
27+
"//executorch/extension/tensor:tensor" + aten_suffix,
28+
],
29+
env = {
30+
"RESOURCES_PATH": "$(location :resources)/resources",
31+
},
32+
platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform.
33+
compiler_flags = [
34+
"-Wno-error=deprecated-declarations",
35+
],
36+
)
2437

2538
runtime.filegroup(
2639
name = "resources",

0 commit comments

Comments
 (0)