Skip to content

Commit 68b75cd

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Refine the tests to compare the result with the error code. (#5358)
Summary: Pull Request resolved: #5358 . Reviewed By: kirklandsign Differential Revision: D62659330 fbshipit-source-id: 415dabeead8bd1a3d6353ff88d7783d9daa06e87
1 parent 0aa75e6 commit 68b75cd

File tree

3 files changed

+61
-58
lines changed

3 files changed

+61
-58
lines changed

extension/module/module.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,13 @@ runtime::Result<runtime::MethodMeta> Module::method_meta(
167167

168168
runtime::Result<std::vector<runtime::EValue>> Module::execute(
169169
const std::string& method_name,
170-
const std::vector<runtime::EValue>& input) {
170+
const std::vector<runtime::EValue>& input_values) {
171171
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
172172
auto& method = methods_.at(method_name).method;
173173

174-
ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
175-
exec_aten::ArrayRef<runtime::EValue>(input.data(), input.size())));
174+
ET_CHECK_OK_OR_RETURN_ERROR(
175+
method->set_inputs(exec_aten::ArrayRef<runtime::EValue>(
176+
input_values.data(), input_values.size())));
176177
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
177178

178179
const auto outputs_size = method->outputs_size();

extension/module/module.h

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -165,34 +165,35 @@ class Module {
165165
const std::string& method_name);
166166

167167
/**
168-
* Execute a specific method with the given input and retrieve output.
169-
* Loads the program and method before executing if needed.
168+
* Execute a specific method with the given input values and retrieve the
169+
* output values. Loads the program and method before executing if needed.
170170
*
171171
* @param[in] method_name The name of the method to execute.
172-
* @param[in] input A vector of input values to be passed to the method.
172+
* @param[in] input_values A vector of input values to be passed to the
173+
* method.
173174
*
174175
* @returns A Result object containing either a vector of output values
175176
* from the method or an error to indicate failure.
176177
*/
177178
ET_NODISCARD
178179
runtime::Result<std::vector<runtime::EValue>> execute(
179180
const std::string& method_name,
180-
const std::vector<runtime::EValue>& input);
181+
const std::vector<runtime::EValue>& input_values);
181182

182183
/**
183184
* Execute a specific method with a single input value.
184185
* Loads the program and method before executing if needed.
185186
*
186187
* @param[in] method_name The name of the method to execute.
187-
* @param[in] input A value to be passed to the method.
188+
* @param[in] input_value A value to be passed to the method.
188189
*
189190
* @returns A Result object containing either a vector of output values
190191
* from the method or an error to indicate failure.
191192
*/
192193
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute(
193194
const std::string& method_name,
194-
const runtime::EValue& input) {
195-
return execute(method_name, std::vector<runtime::EValue>{input});
195+
const runtime::EValue& input_value) {
196+
return execute(method_name, std::vector<runtime::EValue>{input_value});
196197
}
197198

198199
/**
@@ -210,19 +211,20 @@ class Module {
210211
}
211212

212213
/**
213-
* Retrieve the output value of a specific method with the given input.
214+
* Retrieve the output value of a specific method with the given input values.
214215
* Loads the program and method before execution if needed.
215216
*
216217
* @param[in] method_name The name of the method to execute.
217-
* @param[in] input A vector of input values to be passed to the method.
218+
* @param[in] input_values A vector of input values to be passed to the
219+
* method.
218220
*
219221
* @returns A Result object containing either the first output value from the
220222
* method or an error to indicate failure.
221223
*/
222224
ET_NODISCARD inline runtime::Result<runtime::EValue> get(
223225
const std::string& method_name,
224-
const std::vector<runtime::EValue>& input) {
225-
auto result = ET_UNWRAP(execute(method_name, input));
226+
const std::vector<runtime::EValue>& input_values) {
227+
auto result = ET_UNWRAP(execute(method_name, input_values));
226228
if (result.empty()) {
227229
return runtime::Error::InvalidArgument;
228230
}
@@ -234,15 +236,15 @@ class Module {
234236
* Loads the program and method before execution if needed.
235237
*
236238
* @param[in] method_name The name of the method to execute.
237-
* @param[in] input A value to be passed to the method.
239+
* @param[in] input_value A value to be passed to the method.
238240
*
239241
* @returns A Result object containing either the first output value from the
240242
* method or an error to indicate failure.
241243
*/
242244
ET_NODISCARD inline runtime::Result<runtime::EValue> get(
243245
const std::string& method_name,
244-
const runtime::EValue& input) {
245-
return get(method_name, std::vector<runtime::EValue>{input});
246+
const runtime::EValue& input_value) {
247+
return get(method_name, std::vector<runtime::EValue>{input_value});
246248
}
247249

248250
/**
@@ -260,31 +262,31 @@ class Module {
260262
}
261263

262264
/**
263-
* Execute the 'forward' method with the given input and retrieve output.
264-
* Loads the program and method before executing if needed.
265+
* Execute the 'forward' method with the given input values and retrieve the
266+
* output values. Loads the program and method before executing if needed.
265267
*
266-
* @param[in] input A vector of input values for the 'forward' method.
268+
* @param[in] input_values A vector of input values for the 'forward' method.
267269
*
268270
* @returns A Result object containing either a vector of output values
269271
* from the 'forward' method or an error to indicate failure.
270272
*/
271273
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
272-
const std::vector<runtime::EValue>& input) {
273-
return execute("forward", input);
274+
const std::vector<runtime::EValue>& input_values) {
275+
return execute("forward", input_values);
274276
}
275277

276278
/**
277279
* Execute the 'forward' method with a single value.
278280
* Loads the program and method before executing if needed.
279281
*
280-
* @param[in] input A value for the 'forward' method.
282+
* @param[in] input_value A value for the 'forward' method.
281283
*
282284
* @returns A Result object containing either a vector of output values
283285
* from the 'forward' method or an error to indicate failure.
284286
*/
285287
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
286-
const runtime::EValue& input) {
287-
return forward(std::vector<runtime::EValue>{input});
288+
const runtime::EValue& input_value) {
289+
return forward(std::vector<runtime::EValue>{input_value});
288290
}
289291

290292
/**

extension/module/test/module_test.cpp

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ TEST_F(ModuleTest, TestMethodNames) {
5959
Module module(model_path_);
6060

6161
const auto method_names = module.method_names();
62-
EXPECT_TRUE(method_names.ok());
62+
EXPECT_EQ(method_names.error(), Error::Ok);
6363
EXPECT_EQ(method_names.get(), std::unordered_set<std::string>{"forward"});
6464
}
6565

6666
TEST_F(ModuleTest, TestNonExistentMethodNames) {
6767
Module module("/path/to/nonexistent/file.pte");
6868

6969
const auto method_names = module.method_names();
70-
EXPECT_FALSE(method_names.ok());
70+
EXPECT_NE(method_names.error(), Error::Ok);
7171
}
7272

7373
TEST_F(ModuleTest, TestLoadMethod) {
@@ -93,21 +93,21 @@ TEST_F(ModuleTest, TestMethodMeta) {
9393
Module module(model_path_);
9494

9595
const auto meta = module.method_meta("forward");
96-
EXPECT_TRUE(meta.ok());
96+
EXPECT_EQ(meta.error(), Error::Ok);
9797
EXPECT_STREQ(meta->name(), "forward");
9898
EXPECT_EQ(meta->num_inputs(), 2);
9999
EXPECT_EQ(*(meta->input_tag(0)), Tag::Tensor);
100100
EXPECT_EQ(meta->num_outputs(), 1);
101101
EXPECT_EQ(*(meta->output_tag(0)), Tag::Tensor);
102102

103103
const auto input_meta = meta->input_tensor_meta(0);
104-
EXPECT_TRUE(input_meta.ok());
104+
EXPECT_EQ(input_meta.error(), Error::Ok);
105105
EXPECT_EQ(input_meta->scalar_type(), exec_aten::ScalarType::Float);
106106
EXPECT_EQ(input_meta->sizes().size(), 1);
107107
EXPECT_EQ(input_meta->sizes()[0], 1);
108108

109109
const auto output_meta = meta->output_tensor_meta(0);
110-
EXPECT_TRUE(output_meta.ok());
110+
EXPECT_EQ(output_meta.error(), Error::Ok);
111111
EXPECT_EQ(output_meta->scalar_type(), exec_aten::ScalarType::Float);
112112
EXPECT_EQ(output_meta->sizes().size(), 1);
113113
EXPECT_EQ(output_meta->sizes()[0], 1);
@@ -117,15 +117,15 @@ TEST_F(ModuleTest, TestNonExistentMethodMeta) {
117117
Module module("/path/to/nonexistent/file.pte");
118118

119119
const auto meta = module.method_meta("forward");
120-
EXPECT_FALSE(meta.ok());
120+
EXPECT_NE(meta.error(), Error::Ok);
121121
}
122122

123123
TEST_F(ModuleTest, TestExecute) {
124124
Module module(model_path_);
125-
auto tensor = make_tensor_ptr({1}, {1});
125+
auto tensor = make_tensor_ptr({1.f});
126126

127127
const auto result = module.execute("forward", {tensor, tensor});
128-
EXPECT_TRUE(result.ok());
128+
EXPECT_EQ(result.error(), Error::Ok);
129129

130130
EXPECT_TRUE(module.is_loaded());
131131
EXPECT_TRUE(module.is_method_loaded("forward"));
@@ -141,10 +141,10 @@ TEST_F(ModuleTest, TestExecutePreload) {
141141
const auto error = module.load();
142142
EXPECT_EQ(error, Error::Ok);
143143

144-
auto tensor = make_tensor_ptr({1}, {1});
144+
auto tensor = make_tensor_ptr({1.f});
145145

146146
const auto result = module.execute("forward", {tensor, tensor});
147-
EXPECT_TRUE(result.ok());
147+
EXPECT_EQ(result.error(), Error::Ok);
148148

149149
const auto data = result->at(0).toTensor().const_data_ptr<float>();
150150

@@ -157,10 +157,10 @@ TEST_F(ModuleTest, TestExecutePreload_method) {
157157
const auto error = module.load_method("forward");
158158
EXPECT_EQ(error, Error::Ok);
159159

160-
auto tensor = make_tensor_ptr({1}, {1});
160+
auto tensor = make_tensor_ptr({1.f});
161161

162162
const auto result = module.execute("forward", {tensor, tensor});
163-
EXPECT_TRUE(result.ok());
163+
EXPECT_EQ(result.error(), Error::Ok);
164164

165165
const auto data = result->at(0).toTensor().const_data_ptr<float>();
166166

@@ -176,10 +176,10 @@ TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) {
176176
const auto load_method_error = module.load_method("forward");
177177
EXPECT_EQ(load_method_error, Error::Ok);
178178

179-
auto tensor = make_tensor_ptr({1}, {1});
179+
auto tensor = make_tensor_ptr({1.f});
180180

181181
const auto result = module.execute("forward", {tensor, tensor});
182-
EXPECT_TRUE(result.ok());
182+
EXPECT_EQ(result.error(), Error::Ok);
183183

184184
const auto data = result->at(0).toTensor().const_data_ptr<float>();
185185

@@ -191,24 +191,24 @@ TEST_F(ModuleTest, TestExecuteOnNonExistent) {
191191

192192
const auto result = module.execute("forward");
193193

194-
EXPECT_FALSE(result.ok());
194+
EXPECT_NE(result.error(), Error::Ok);
195195
}
196196

197197
TEST_F(ModuleTest, TestExecuteOnCurrupted) {
198198
Module module("/dev/null");
199199

200200
const auto result = module.execute("forward");
201201

202-
EXPECT_FALSE(result.ok());
202+
EXPECT_NE(result.error(), Error::Ok);
203203
}
204204

205205
TEST_F(ModuleTest, TestGet) {
206206
Module module(model_path_);
207207

208-
auto tensor = make_tensor_ptr({1}, {1});
208+
auto tensor = make_tensor_ptr({1.f});
209209

210210
const auto result = module.get("forward", {tensor, tensor});
211-
EXPECT_TRUE(result.ok());
211+
EXPECT_EQ(result.error(), Error::Ok);
212212
const auto data = result->toTensor().const_data_ptr<float>();
213213
EXPECT_NEAR(data[0], 2, 1e-5);
214214
}
@@ -218,15 +218,15 @@ TEST_F(ModuleTest, TestForward) {
218218
auto tensor = make_tensor_ptr({21.f});
219219

220220
const auto result = module->forward({tensor, tensor});
221-
EXPECT_TRUE(result.ok());
221+
EXPECT_EQ(result.error(), Error::Ok);
222222

223223
const auto data = result->at(0).toTensor().const_data_ptr<float>();
224224

225225
EXPECT_NEAR(data[0], 42, 1e-5);
226226

227-
auto tensor2 = make_tensor_ptr({1}, {2, 3});
227+
auto tensor2 = make_tensor_ptr({2.f});
228228
const auto result2 = module->forward({tensor2, tensor2});
229-
EXPECT_TRUE(result2.ok());
229+
EXPECT_EQ(result2.error(), Error::Ok);
230230

231231
const auto data2 = result->at(0).toTensor().const_data_ptr<float>();
232232

@@ -238,7 +238,7 @@ TEST_F(ModuleTest, TestForwardWithInvalidInputs) {
238238

239239
const auto result = module.forward(EValue());
240240

241-
EXPECT_FALSE(result.ok());
241+
EXPECT_NE(result.error(), Error::Ok);
242242
}
243243

244244
TEST_F(ModuleTest, TestProgramSharingBetweenModules) {
@@ -253,10 +253,10 @@ TEST_F(ModuleTest, TestProgramSharingBetweenModules) {
253253
EXPECT_TRUE(module2.is_loaded());
254254

255255
auto method_names1 = module1.method_names();
256-
EXPECT_TRUE(method_names1.ok());
256+
EXPECT_EQ(method_names1.error(), Error::Ok);
257257

258258
auto method_names2 = module2.method_names();
259-
EXPECT_TRUE(method_names2.ok());
259+
EXPECT_EQ(method_names2.error(), Error::Ok);
260260
EXPECT_EQ(method_names1.get(), method_names2.get());
261261

262262
auto load_method_error = module1.load_method("forward");
@@ -271,7 +271,7 @@ TEST_F(ModuleTest, TestProgramSharingBetweenModules) {
271271

272272
TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
273273
auto loader = FileDataLoader::from(model_path_.c_str());
274-
EXPECT_TRUE(loader.ok());
274+
EXPECT_EQ(loader.error(), Error::Ok);
275275
auto data_loader = std::make_unique<FileDataLoader>(std::move(loader.get()));
276276

277277
auto module1 = std::make_unique<Module>(std::move(data_loader));
@@ -280,29 +280,29 @@ TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
280280
EXPECT_EQ(load_error, Error::Ok);
281281
EXPECT_TRUE(module1->is_loaded());
282282

283-
auto tensor = make_tensor_ptr({1}, {1});
283+
auto tensor = make_tensor_ptr({1.f});
284284

285285
const auto result1 = module1->execute("forward", {tensor, tensor});
286-
EXPECT_TRUE(result1.ok());
286+
EXPECT_EQ(result1.error(), Error::Ok);
287287

288288
auto module2 = std::make_unique<Module>(module1->program());
289289

290290
const auto result2 = module2->execute("forward", {tensor, tensor});
291-
EXPECT_TRUE(result2.ok());
291+
EXPECT_EQ(result2.error(), Error::Ok);
292292

293293
module1 = std::make_unique<Module>("/path/to/nonexistent/file.pte");
294294
EXPECT_FALSE(module1->is_loaded());
295295

296296
const auto result3 = module2->execute("forward", {tensor, tensor});
297-
EXPECT_TRUE(result3.ok());
297+
EXPECT_EQ(result3.error(), Error::Ok);
298298
}
299299

300300
TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
301301
std::shared_ptr<Program> shared_program;
302302

303303
{
304304
auto loader = FileDataLoader::from(model_path_.c_str());
305-
EXPECT_TRUE(loader.ok());
305+
EXPECT_EQ(loader.error(), Error::Ok);
306306
auto data_loader =
307307
std::make_unique<FileDataLoader>(std::move(loader.get()));
308308
auto* data_loader_ptr = data_loader.get();
@@ -325,10 +325,10 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
325325

326326
EXPECT_EQ(module.program(), shared_program);
327327

328-
auto tensor = make_tensor_ptr({1}, {1});
328+
auto tensor = make_tensor_ptr({1.f});
329329

330330
const auto result = module.execute("forward", {tensor, tensor});
331-
EXPECT_TRUE(result.ok());
331+
EXPECT_EQ(result.error(), Error::Ok);
332332

333333
auto data = result->at(0).toTensor().const_data_ptr<float>();
334334

@@ -355,7 +355,7 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
355355
auto tensor = from_blob((void*)input.data(), {1});
356356

357357
const auto result = module.forward({tensor, tensor});
358-
EXPECT_TRUE(result.ok());
358+
EXPECT_EQ(result.error(), Error::Ok);
359359

360360
const auto data = result->at(0).toTensor().const_data_ptr<float>();
361361
EXPECT_NEAR(data[0], (input[0] * 2), 1e-5);

0 commit comments

Comments
 (0)