Skip to content

Refine the tests to compare the result with the error code. #5358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,13 @@ runtime::Result<runtime::MethodMeta> Module::method_meta(

runtime::Result<std::vector<runtime::EValue>> Module::execute(
const std::string& method_name,
const std::vector<runtime::EValue>& input) {
const std::vector<runtime::EValue>& input_values) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& method = methods_.at(method_name).method;

ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
exec_aten::ArrayRef<runtime::EValue>(input.data(), input.size())));
ET_CHECK_OK_OR_RETURN_ERROR(
method->set_inputs(exec_aten::ArrayRef<runtime::EValue>(
input_values.data(), input_values.size())));
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());

const auto outputs_size = method->outputs_size();
Expand Down
46 changes: 24 additions & 22 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,34 +165,35 @@ class Module {
const std::string& method_name);

/**
* Execute a specific method with the given input and retrieve output.
* Loads the program and method before executing if needed.
* Execute a specific method with the given input values and retrieve the
* output values. Loads the program and method before executing if needed.
*
* @param[in] method_name The name of the method to execute.
* @param[in] input A vector of input values to be passed to the method.
* @param[in] input_values A vector of input values to be passed to the
* method.
*
* @returns A Result object containing either a vector of output values
* from the method or an error to indicate failure.
*/
ET_NODISCARD
runtime::Result<std::vector<runtime::EValue>> execute(
const std::string& method_name,
const std::vector<runtime::EValue>& input);
const std::vector<runtime::EValue>& input_values);

/**
* Execute a specific method with a single input value.
* Loads the program and method before executing if needed.
*
* @param[in] method_name The name of the method to execute.
* @param[in] input A value to be passed to the method.
* @param[in] input_value A value to be passed to the method.
*
* @returns A Result object containing either a vector of output values
* from the method or an error to indicate failure.
*/
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute(
const std::string& method_name,
const runtime::EValue& input) {
return execute(method_name, std::vector<runtime::EValue>{input});
const runtime::EValue& input_value) {
return execute(method_name, std::vector<runtime::EValue>{input_value});
}

/**
Expand All @@ -210,19 +211,20 @@ class Module {
}

/**
* Retrieve the output value of a specific method with the given input.
* Retrieve the output value of a specific method with the given input values.
* Loads the program and method before execution if needed.
*
* @param[in] method_name The name of the method to execute.
* @param[in] input A vector of input values to be passed to the method.
* @param[in] input_values A vector of input values to be passed to the
* method.
*
* @returns A Result object containing either the first output value from the
* method or an error to indicate failure.
*/
ET_NODISCARD inline runtime::Result<runtime::EValue> get(
const std::string& method_name,
const std::vector<runtime::EValue>& input) {
auto result = ET_UNWRAP(execute(method_name, input));
const std::vector<runtime::EValue>& input_values) {
auto result = ET_UNWRAP(execute(method_name, input_values));
if (result.empty()) {
return runtime::Error::InvalidArgument;
}
Expand All @@ -234,15 +236,15 @@ class Module {
* Loads the program and method before execution if needed.
*
* @param[in] method_name The name of the method to execute.
* @param[in] input A value to be passed to the method.
* @param[in] input_value A value to be passed to the method.
*
* @returns A Result object containing either the first output value from the
* method or an error to indicate failure.
*/
ET_NODISCARD inline runtime::Result<runtime::EValue> get(
const std::string& method_name,
const runtime::EValue& input) {
return get(method_name, std::vector<runtime::EValue>{input});
const runtime::EValue& input_value) {
return get(method_name, std::vector<runtime::EValue>{input_value});
}

/**
Expand All @@ -260,31 +262,31 @@ class Module {
}

/**
* Execute the 'forward' method with the given input and retrieve output.
* Loads the program and method before executing if needed.
* Execute the 'forward' method with the given input values and retrieve the
* output values. Loads the program and method before executing if needed.
*
* @param[in] input A vector of input values for the 'forward' method.
* @param[in] input_values A vector of input values for the 'forward' method.
*
* @returns A Result object containing either a vector of output values
* from the 'forward' method or an error to indicate failure.
*/
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
const std::vector<runtime::EValue>& input) {
return execute("forward", input);
const std::vector<runtime::EValue>& input_values) {
return execute("forward", input_values);
}

/**
* Execute the 'forward' method with a single value.
* Loads the program and method before executing if needed.
*
* @param[in] input A value for the 'forward' method.
* @param[in] input_value A value for the 'forward' method.
*
* @returns A Result object containing either a vector of output values
* from the 'forward' method or an error to indicate failure.
*/
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
const runtime::EValue& input) {
return forward(std::vector<runtime::EValue>{input});
const runtime::EValue& input_value) {
return forward(std::vector<runtime::EValue>{input_value});
}

/**
Expand Down
66 changes: 33 additions & 33 deletions extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ TEST_F(ModuleTest, TestMethodNames) {
Module module(model_path_);

const auto method_names = module.method_names();
EXPECT_TRUE(method_names.ok());
EXPECT_EQ(method_names.error(), Error::Ok);
EXPECT_EQ(method_names.get(), std::unordered_set<std::string>{"forward"});
}

TEST_F(ModuleTest, TestNonExistentMethodNames) {
Module module("/path/to/nonexistent/file.pte");

const auto method_names = module.method_names();
EXPECT_FALSE(method_names.ok());
EXPECT_NE(method_names.error(), Error::Ok);
}

TEST_F(ModuleTest, TestLoadMethod) {
Expand All @@ -93,21 +93,21 @@ TEST_F(ModuleTest, TestMethodMeta) {
Module module(model_path_);

const auto meta = module.method_meta("forward");
EXPECT_TRUE(meta.ok());
EXPECT_EQ(meta.error(), Error::Ok);
EXPECT_STREQ(meta->name(), "forward");
EXPECT_EQ(meta->num_inputs(), 2);
EXPECT_EQ(*(meta->input_tag(0)), Tag::Tensor);
EXPECT_EQ(meta->num_outputs(), 1);
EXPECT_EQ(*(meta->output_tag(0)), Tag::Tensor);

const auto input_meta = meta->input_tensor_meta(0);
EXPECT_TRUE(input_meta.ok());
EXPECT_EQ(input_meta.error(), Error::Ok);
EXPECT_EQ(input_meta->scalar_type(), exec_aten::ScalarType::Float);
EXPECT_EQ(input_meta->sizes().size(), 1);
EXPECT_EQ(input_meta->sizes()[0], 1);

const auto output_meta = meta->output_tensor_meta(0);
EXPECT_TRUE(output_meta.ok());
EXPECT_EQ(output_meta.error(), Error::Ok);
EXPECT_EQ(output_meta->scalar_type(), exec_aten::ScalarType::Float);
EXPECT_EQ(output_meta->sizes().size(), 1);
EXPECT_EQ(output_meta->sizes()[0], 1);
Expand All @@ -117,15 +117,15 @@ TEST_F(ModuleTest, TestNonExistentMethodMeta) {
Module module("/path/to/nonexistent/file.pte");

const auto meta = module.method_meta("forward");
EXPECT_FALSE(meta.ok());
EXPECT_NE(meta.error(), Error::Ok);
}

TEST_F(ModuleTest, TestExecute) {
Module module(model_path_);
auto tensor = make_tensor_ptr({1}, {1});
auto tensor = make_tensor_ptr({1.f});

const auto result = module.execute("forward", {tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);

EXPECT_TRUE(module.is_loaded());
EXPECT_TRUE(module.is_method_loaded("forward"));
Expand All @@ -141,10 +141,10 @@ TEST_F(ModuleTest, TestExecutePreload) {
const auto error = module.load();
EXPECT_EQ(error, Error::Ok);

auto tensor = make_tensor_ptr({1}, {1});
auto tensor = make_tensor_ptr({1.f});

const auto result = module.execute("forward", {tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);

const auto data = result->at(0).toTensor().const_data_ptr<float>();

Expand All @@ -157,10 +157,10 @@ TEST_F(ModuleTest, TestExecutePreload_method) {
const auto error = module.load_method("forward");
EXPECT_EQ(error, Error::Ok);

auto tensor = make_tensor_ptr({1}, {1});
auto tensor = make_tensor_ptr({1.f});

const auto result = module.execute("forward", {tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);

const auto data = result->at(0).toTensor().const_data_ptr<float>();

Expand All @@ -176,10 +176,10 @@ TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) {
const auto load_method_error = module.load_method("forward");
EXPECT_EQ(load_method_error, Error::Ok);

auto tensor = make_tensor_ptr({1}, {1});
auto tensor = make_tensor_ptr({1.f});

const auto result = module.execute("forward", {tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);

const auto data = result->at(0).toTensor().const_data_ptr<float>();

Expand All @@ -191,24 +191,24 @@ TEST_F(ModuleTest, TestExecuteOnNonExistent) {

const auto result = module.execute("forward");

EXPECT_FALSE(result.ok());
EXPECT_NE(result.error(), Error::Ok);
}

TEST_F(ModuleTest, TestExecuteOnCurrupted) {
Module module("/dev/null");

const auto result = module.execute("forward");

EXPECT_FALSE(result.ok());
EXPECT_NE(result.error(), Error::Ok);
}

TEST_F(ModuleTest, TestGet) {
Module module(model_path_);

auto tensor = make_tensor_ptr({1}, {1});
auto tensor = make_tensor_ptr({1.f});

const auto result = module.get("forward", {tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);
const auto data = result->toTensor().const_data_ptr<float>();
EXPECT_NEAR(data[0], 2, 1e-5);
}
Expand All @@ -218,15 +218,15 @@ TEST_F(ModuleTest, TestForward) {
auto tensor = make_tensor_ptr({21.f});

const auto result = module->forward({tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);

const auto data = result->at(0).toTensor().const_data_ptr<float>();

EXPECT_NEAR(data[0], 42, 1e-5);

auto tensor2 = make_tensor_ptr({1}, {2, 3});
auto tensor2 = make_tensor_ptr({2.f});
const auto result2 = module->forward({tensor2, tensor2});
EXPECT_TRUE(result2.ok());
EXPECT_EQ(result2.error(), Error::Ok);

const auto data2 = result->at(0).toTensor().const_data_ptr<float>();

Expand All @@ -238,7 +238,7 @@ TEST_F(ModuleTest, TestForwardWithInvalidInputs) {

const auto result = module.forward(EValue());

EXPECT_FALSE(result.ok());
EXPECT_NE(result.error(), Error::Ok);
}

TEST_F(ModuleTest, TestProgramSharingBetweenModules) {
Expand All @@ -253,10 +253,10 @@ TEST_F(ModuleTest, TestProgramSharingBetweenModules) {
EXPECT_TRUE(module2.is_loaded());

auto method_names1 = module1.method_names();
EXPECT_TRUE(method_names1.ok());
EXPECT_EQ(method_names1.error(), Error::Ok);

auto method_names2 = module2.method_names();
EXPECT_TRUE(method_names2.ok());
EXPECT_EQ(method_names2.error(), Error::Ok);
EXPECT_EQ(method_names1.get(), method_names2.get());

auto load_method_error = module1.load_method("forward");
Expand All @@ -271,7 +271,7 @@ TEST_F(ModuleTest, TestProgramSharingBetweenModules) {

TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
auto loader = FileDataLoader::from(model_path_.c_str());
EXPECT_TRUE(loader.ok());
EXPECT_EQ(loader.error(), Error::Ok);
auto data_loader = std::make_unique<FileDataLoader>(std::move(loader.get()));

auto module1 = std::make_unique<Module>(std::move(data_loader));
Expand All @@ -280,29 +280,29 @@ TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
EXPECT_EQ(load_error, Error::Ok);
EXPECT_TRUE(module1->is_loaded());

auto tensor = make_tensor_ptr({1}, {1});
auto tensor = make_tensor_ptr({1.f});

const auto result1 = module1->execute("forward", {tensor, tensor});
EXPECT_TRUE(result1.ok());
EXPECT_EQ(result1.error(), Error::Ok);

auto module2 = std::make_unique<Module>(module1->program());

const auto result2 = module2->execute("forward", {tensor, tensor});
EXPECT_TRUE(result2.ok());
EXPECT_EQ(result2.error(), Error::Ok);

module1 = std::make_unique<Module>("/path/to/nonexistent/file.pte");
EXPECT_FALSE(module1->is_loaded());

const auto result3 = module2->execute("forward", {tensor, tensor});
EXPECT_TRUE(result3.ok());
EXPECT_EQ(result3.error(), Error::Ok);
}

TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
std::shared_ptr<Program> shared_program;

{
auto loader = FileDataLoader::from(model_path_.c_str());
EXPECT_TRUE(loader.ok());
EXPECT_EQ(loader.error(), Error::Ok);
auto data_loader =
std::make_unique<FileDataLoader>(std::move(loader.get()));
auto* data_loader_ptr = data_loader.get();
Expand All @@ -325,10 +325,10 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {

EXPECT_EQ(module.program(), shared_program);

auto tensor = make_tensor_ptr({1}, {1});
auto tensor = make_tensor_ptr({1.f});

const auto result = module.execute("forward", {tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);

auto data = result->at(0).toTensor().const_data_ptr<float>();

Expand All @@ -355,7 +355,7 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
auto tensor = from_blob((void*)input.data(), {1});

const auto result = module.forward({tensor, tensor});
EXPECT_TRUE(result.ok());
EXPECT_EQ(result.error(), Error::Ok);

const auto data = result->at(0).toTensor().const_data_ptr<float>();
EXPECT_NEAR(data[0], (input[0] * 2), 1e-5);
Expand Down
Loading