diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 4e0e70936df..b06fe1279f0 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -167,12 +167,13 @@ runtime::Result Module::method_meta( runtime::Result> Module::execute( const std::string& method_name, - const std::vector& input) { + const std::vector& input_values) { ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); auto& method = methods_.at(method_name).method; - ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs( - exec_aten::ArrayRef(input.data(), input.size()))); + ET_CHECK_OK_OR_RETURN_ERROR( + method->set_inputs(exec_aten::ArrayRef( + input_values.data(), input_values.size()))); ET_CHECK_OK_OR_RETURN_ERROR(method->execute()); const auto outputs_size = method->outputs_size(); diff --git a/extension/module/module.h b/extension/module/module.h index 1a3855c5c01..1197eace331 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -165,11 +165,12 @@ class Module { const std::string& method_name); /** - * Execute a specific method with the given input and retrieve output. - * Loads the program and method before executing if needed. + * Execute a specific method with the given input values and retrieve the + * output values. Loads the program and method before executing if needed. * * @param[in] method_name The name of the method to execute. - * @param[in] input A vector of input values to be passed to the method. + * @param[in] input_values A vector of input values to be passed to the + * method. * * @returns A Result object containing either a vector of output values * from the method or an error to indicate failure. @@ -177,22 +178,22 @@ class Module { ET_NODISCARD runtime::Result> execute( const std::string& method_name, - const std::vector& input); + const std::vector& input_values); /** * Execute a specific method with a single input value. * Loads the program and method before executing if needed. * * @param[in] method_name The name of the method to execute. - * @param[in] input A value to be passed to the method. + * @param[in] input_value A value to be passed to the method. * * @returns A Result object containing either a vector of output values * from the method or an error to indicate failure. */ ET_NODISCARD inline runtime::Result> execute( const std::string& method_name, - const runtime::EValue& input) { - return execute(method_name, std::vector{input}); + const runtime::EValue& input_value) { + return execute(method_name, std::vector{input_value}); } /** @@ -210,19 +211,20 @@ class Module { } /** - * Retrieve the output value of a specific method with the given input. + * Retrieve the output value of a specific method with the given input values. * Loads the program and method before execution if needed. * * @param[in] method_name The name of the method to execute. - * @param[in] input A vector of input values to be passed to the method. + * @param[in] input_values A vector of input values to be passed to the + * method. * * @returns A Result object containing either the first output value from the * method or an error to indicate failure. */ ET_NODISCARD inline runtime::Result get( const std::string& method_name, - const std::vector& input) { - auto result = ET_UNWRAP(execute(method_name, input)); + const std::vector& input_values) { + auto result = ET_UNWRAP(execute(method_name, input_values)); if (result.empty()) { return runtime::Error::InvalidArgument; } @@ -234,15 +236,15 @@ class Module { * Loads the program and method before execution if needed. * * @param[in] method_name The name of the method to execute. - * @param[in] input A value to be passed to the method. + * @param[in] input_value A value to be passed to the method. * * @returns A Result object containing either the first output value from the * method or an error to indicate failure. */ ET_NODISCARD inline runtime::Result get( const std::string& method_name, - const runtime::EValue& input) { - return get(method_name, std::vector{input}); + const runtime::EValue& input_value) { + return get(method_name, std::vector{input_value}); } /** @@ -260,31 +262,31 @@ class Module { } /** - * Execute the 'forward' method with the given input and retrieve output. - * Loads the program and method before executing if needed. + * Execute the 'forward' method with the given input values and retrieve the + * output values. Loads the program and method before executing if needed. * - * @param[in] input A vector of input values for the 'forward' method. + * @param[in] input_values A vector of input values for the 'forward' method. * * @returns A Result object containing either a vector of output values * from the 'forward' method or an error to indicate failure. */ ET_NODISCARD inline runtime::Result> forward( - const std::vector& input) { - return execute("forward", input); + const std::vector& input_values) { + return execute("forward", input_values); } /** * Execute the 'forward' method with a single value. * Loads the program and method before executing if needed. * - * @param[in] input A value for the 'forward' method. + * @param[in] input_value A value for the 'forward' method. * * @returns A Result object containing either a vector of output values * from the 'forward' method or an error to indicate failure. */ ET_NODISCARD inline runtime::Result> forward( - const runtime::EValue& input) { - return forward(std::vector{input}); + const runtime::EValue& input_value) { + return forward(std::vector{input_value}); } /** diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index f1871d631f1..6f18c8d9cbf 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -59,7 +59,7 @@ TEST_F(ModuleTest, TestMethodNames) { Module module(model_path_); const auto method_names = module.method_names(); - EXPECT_TRUE(method_names.ok()); + EXPECT_EQ(method_names.error(), Error::Ok); EXPECT_EQ(method_names.get(), std::unordered_set{"forward"}); } @@ -67,7 +67,7 @@ TEST_F(ModuleTest, TestNonExistentMethodNames) { Module module("/path/to/nonexistent/file.pte"); const auto method_names = module.method_names(); - EXPECT_FALSE(method_names.ok()); + EXPECT_NE(method_names.error(), Error::Ok); } TEST_F(ModuleTest, TestLoadMethod) { @@ -93,7 +93,7 @@ TEST_F(ModuleTest, TestMethodMeta) { Module module(model_path_); const auto meta = module.method_meta("forward"); - EXPECT_TRUE(meta.ok()); + EXPECT_EQ(meta.error(), Error::Ok); EXPECT_STREQ(meta->name(), "forward"); EXPECT_EQ(meta->num_inputs(), 2); EXPECT_EQ(*(meta->input_tag(0)), Tag::Tensor); @@ -101,13 +101,13 @@ TEST_F(ModuleTest, TestMethodMeta) { EXPECT_EQ(*(meta->output_tag(0)), Tag::Tensor); const auto input_meta = meta->input_tensor_meta(0); - EXPECT_TRUE(input_meta.ok()); + EXPECT_EQ(input_meta.error(), Error::Ok); EXPECT_EQ(input_meta->scalar_type(), exec_aten::ScalarType::Float); EXPECT_EQ(input_meta->sizes().size(), 1); EXPECT_EQ(input_meta->sizes()[0], 1); const auto output_meta = meta->output_tensor_meta(0); - EXPECT_TRUE(output_meta.ok()); + EXPECT_EQ(output_meta.error(), Error::Ok); EXPECT_EQ(output_meta->scalar_type(), exec_aten::ScalarType::Float); EXPECT_EQ(output_meta->sizes().size(), 1); EXPECT_EQ(output_meta->sizes()[0], 1); @@ -117,15 +117,15 @@ TEST_F(ModuleTest, TestNonExistentMethodMeta) { Module module("/path/to/nonexistent/file.pte"); const auto meta = module.method_meta("forward"); - EXPECT_FALSE(meta.ok()); + EXPECT_NE(meta.error(), Error::Ok); } TEST_F(ModuleTest, TestExecute) { Module module(model_path_); - auto tensor = make_tensor_ptr({1}, {1}); + auto tensor = make_tensor_ptr({1.f}); const auto result = module.execute("forward", {tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); EXPECT_TRUE(module.is_loaded()); EXPECT_TRUE(module.is_method_loaded("forward")); @@ -141,10 +141,10 @@ TEST_F(ModuleTest, TestExecutePreload) { const auto error = module.load(); EXPECT_EQ(error, Error::Ok); - auto tensor = make_tensor_ptr({1}, {1}); + auto tensor = make_tensor_ptr({1.f}); const auto result = module.execute("forward", {tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -157,10 +157,10 @@ TEST_F(ModuleTest, TestExecutePreload_method) { const auto error = module.load_method("forward"); EXPECT_EQ(error, Error::Ok); - auto tensor = make_tensor_ptr({1}, {1}); + auto tensor = make_tensor_ptr({1.f}); const auto result = module.execute("forward", {tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -176,10 +176,10 @@ TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) { const auto load_method_error = module.load_method("forward"); EXPECT_EQ(load_method_error, Error::Ok); - auto tensor = make_tensor_ptr({1}, {1}); + auto tensor = make_tensor_ptr({1.f}); const auto result = module.execute("forward", {tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -191,7 +191,7 @@ TEST_F(ModuleTest, TestExecuteOnNonExistent) { const auto result = module.execute("forward"); - EXPECT_FALSE(result.ok()); + EXPECT_NE(result.error(), Error::Ok); } TEST_F(ModuleTest, TestExecuteOnCurrupted) { @@ -199,16 +199,16 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) { const auto result = module.execute("forward"); - EXPECT_FALSE(result.ok()); + EXPECT_NE(result.error(), Error::Ok); } TEST_F(ModuleTest, TestGet) { Module module(model_path_); - auto tensor = make_tensor_ptr({1}, {1}); + auto tensor = make_tensor_ptr({1.f}); const auto result = module.get("forward", {tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); const auto data = result->toTensor().const_data_ptr(); EXPECT_NEAR(data[0], 2, 1e-5); } @@ -218,15 +218,15 @@ TEST_F(ModuleTest, TestForward) { auto tensor = make_tensor_ptr({21.f}); const auto result = module->forward({tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); const auto data = result->at(0).toTensor().const_data_ptr(); EXPECT_NEAR(data[0], 42, 1e-5); - auto tensor2 = make_tensor_ptr({1}, {2, 3}); + auto tensor2 = make_tensor_ptr({2.f}); const auto result2 = module->forward({tensor2, tensor2}); - EXPECT_TRUE(result2.ok()); + EXPECT_EQ(result2.error(), Error::Ok); const auto data2 = result->at(0).toTensor().const_data_ptr(); @@ -238,7 +238,7 @@ TEST_F(ModuleTest, TestForwardWithInvalidInputs) { const auto result = module.forward(EValue()); - EXPECT_FALSE(result.ok()); + EXPECT_NE(result.error(), Error::Ok); } TEST_F(ModuleTest, TestProgramSharingBetweenModules) { @@ -253,10 +253,10 @@ TEST_F(ModuleTest, TestProgramSharingBetweenModules) { EXPECT_TRUE(module2.is_loaded()); auto method_names1 = module1.method_names(); - EXPECT_TRUE(method_names1.ok()); + EXPECT_EQ(method_names1.error(), Error::Ok); auto method_names2 = module2.method_names(); - EXPECT_TRUE(method_names2.ok()); + EXPECT_EQ(method_names2.error(), Error::Ok); EXPECT_EQ(method_names1.get(), method_names2.get()); auto load_method_error = module1.load_method("forward"); @@ -271,7 +271,7 @@ TEST_F(ModuleTest, TestProgramSharingBetweenModules) { TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) { auto loader = FileDataLoader::from(model_path_.c_str()); - EXPECT_TRUE(loader.ok()); + EXPECT_EQ(loader.error(), Error::Ok); auto data_loader = std::make_unique(std::move(loader.get())); auto module1 = std::make_unique(std::move(data_loader)); @@ -280,21 +280,21 @@ TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) { EXPECT_EQ(load_error, Error::Ok); EXPECT_TRUE(module1->is_loaded()); - auto tensor = make_tensor_ptr({1}, {1}); + auto tensor = make_tensor_ptr({1.f}); const auto result1 = module1->execute("forward", {tensor, tensor}); - EXPECT_TRUE(result1.ok()); + EXPECT_EQ(result1.error(), Error::Ok); auto module2 = std::make_unique(module1->program()); const auto result2 = module2->execute("forward", {tensor, tensor}); - EXPECT_TRUE(result2.ok()); + EXPECT_EQ(result2.error(), Error::Ok); module1 = std::make_unique("/path/to/nonexistent/file.pte"); EXPECT_FALSE(module1->is_loaded()); const auto result3 = module2->execute("forward", {tensor, tensor}); - EXPECT_TRUE(result3.ok()); + EXPECT_EQ(result3.error(), Error::Ok); } TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) { @@ -302,7 +302,7 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) { { auto loader = FileDataLoader::from(model_path_.c_str()); - EXPECT_TRUE(loader.ok()); + EXPECT_EQ(loader.error(), Error::Ok); auto data_loader = std::make_unique(std::move(loader.get())); auto* data_loader_ptr = data_loader.get(); @@ -325,10 +325,10 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) { EXPECT_EQ(module.program(), shared_program); - auto tensor = make_tensor_ptr({1}, {1}); + auto tensor = make_tensor_ptr({1.f}); const auto result = module.execute("forward", {tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); auto data = result->at(0).toTensor().const_data_ptr(); @@ -355,7 +355,7 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) { auto tensor = from_blob((void*)input.data(), {1}); const auto result = module.forward({tensor, tensor}); - EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.error(), Error::Ok); const auto data = result->at(0).toTensor().const_data_ptr(); EXPECT_NEAR(data[0], (input[0] * 2), 1e-5);