Skip to content

Commit a7618c5

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add API to set inputs independently from execution. (#5356)
Summary: Pull Request resolved: #5356 Strive to match the Method API. Reviewed By: dbort Differential Revision: D62653459 fbshipit-source-id: 86800b4f93b71b3cee1f610d8c51d7db11969818
1 parent 68b75cd commit a7618c5

File tree

3 files changed

+142
-3
lines changed

3 files changed

+142
-3
lines changed

extension/module/module.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ runtime::Error Module::load_method(
154154
temp_allocator_.get());
155155
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156156
method_name.c_str(), method_holder.memory_manager.get(), tracer));
157+
method_holder.inputs.resize(method_holder.method->inputs_size());
157158
methods_.emplace(method_name, std::move(method_holder));
158159
}
159160
return runtime::Error::Ok;
@@ -170,10 +171,19 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
170171
const std::vector<runtime::EValue>& input_values) {
171172
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
172173
auto& method = methods_.at(method_name).method;
174+
auto& inputs = methods_.at(method_name).inputs;
173175

174-
ET_CHECK_OK_OR_RETURN_ERROR(
175-
method->set_inputs(exec_aten::ArrayRef<runtime::EValue>(
176-
input_values.data(), input_values.size())));
176+
for (size_t i = 0; i < input_values.size(); ++i) {
177+
if (!input_values[i].isNone()) {
178+
inputs[i] = input_values[i];
179+
}
180+
}
181+
for (size_t i = 0; i < inputs.size(); ++i) {
182+
ET_CHECK_OR_RETURN_ERROR(
183+
!inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
184+
}
185+
ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
186+
exec_aten::ArrayRef<runtime::EValue>(inputs.data(), inputs.size())));
177187
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
178188

179189
const auto outputs_size = method->outputs_size();
@@ -184,6 +194,30 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
184194
return outputs;
185195
}
186196

197+
runtime::Error Module::set_input(
198+
const std::string& method_name,
199+
const runtime::EValue& input_value,
200+
size_t input_index) {
201+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
202+
methods_.at(method_name).inputs.at(input_index) = input_value;
203+
return runtime::Error::Ok;
204+
}
205+
206+
runtime::Error Module::set_inputs(
207+
const std::string& method_name,
208+
const std::vector<runtime::EValue>& input_values) {
209+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
210+
auto& inputs = methods_.at(method_name).inputs;
211+
ET_CHECK_OR_RETURN_ERROR(
212+
inputs.size() == input_values.size(),
213+
InvalidArgument,
214+
"input size: %zu does not match method input size: %zu",
215+
input_values.size(),
216+
inputs.size());
217+
inputs = input_values;
218+
return runtime::Error::Ok;
219+
}
220+
187221
runtime::Error Module::set_output_data_ptr(
188222
runtime::EValue output_value,
189223
size_t output_index,

extension/module/module.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,62 @@ class Module {
300300
return forward(std::vector<runtime::EValue>{});
301301
}
302302

303+
/**
304+
* Sets a single input value for a specific method.
305+
*
306+
* @param[in] method_name The name of the method.
307+
* @param[in] input_value The EValue to set as the method input.
308+
* @param[in] input_index Zero-based index of the input to set.
309+
*
310+
* @returns An Error to indicate success or failure.
311+
*/
312+
ET_NODISCARD
313+
runtime::Error set_input(
314+
const std::string& method_name,
315+
const runtime::EValue& input_value,
316+
size_t input_index);
317+
318+
/**
319+
* Sets a single input value for the "forward" method.
320+
*
321+
* @param[in] input_value The EValue to set as the method input.
322+
* @param[in] input_index Zero-based index of the input to set.
323+
*
324+
* @returns An Error to indicate success or failure.
325+
*/
326+
ET_NODISCARD
327+
inline runtime::Error set_input(
328+
const runtime::EValue& input_value,
329+
size_t input_index) {
330+
return set_input("forward", input_value, input_index);
331+
}
332+
333+
/**
334+
* Sets all input values for a specific method.
335+
*
336+
* @param[in] method_name The name of the method.
337+
* @param[in] input_values A vector of EValues to set as the method inputs.
338+
*
339+
* @returns An Error to indicate success or failure.
340+
*/
341+
ET_NODISCARD
342+
runtime::Error set_inputs(
343+
const std::string& method_name,
344+
const std::vector<runtime::EValue>& input_values);
345+
346+
/**
347+
* Sets all input values for the "forward" method.
348+
*
349+
* @param[in] input_values A vector of EValues to set as the method inputs.
350+
*
351+
* @returns An Error to indicate success or failure.
352+
*/
353+
ET_NODISCARD
354+
inline runtime::Error set_inputs(
355+
const std::vector<runtime::EValue>& input_values) {
356+
return set_inputs("forward", input_values);
357+
}
358+
303359
/**
304360
* Retrieves the EventTracer instance being used by the Module.
305361
* EventTracer is used for tracking and logging events during the execution
@@ -332,6 +388,7 @@ class Module {
332388
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
333389
std::unique_ptr<runtime::MemoryManager> memory_manager;
334390
std::unique_ptr<runtime::Method> method;
391+
std::vector<runtime::EValue> inputs;
335392
};
336393

337394
private:

extension/module/test/module_test.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,51 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
373373
t4.join();
374374
t5.join();
375375
}
376+
377+
TEST_F(ModuleTest, TestSetInputsBeforeExecute) {
378+
Module module(model_path_);
379+
380+
auto tensor1 = make_tensor_ptr({4.f});
381+
auto tensor2 = make_tensor_ptr({5.f});
382+
383+
EXPECT_EQ(module.set_inputs({tensor1, tensor2}), Error::Ok);
384+
385+
const auto result = module.forward();
386+
EXPECT_EQ(result.error(), Error::Ok);
387+
388+
const auto data = result->at(0).toTensor().const_data_ptr<float>();
389+
EXPECT_NEAR(data[0], 9, 1e-5);
390+
}
391+
392+
TEST_F(ModuleTest, TestSetInputCombinedWithExecute) {
393+
Module module(model_path_);
394+
395+
auto tensor1 = make_tensor_ptr({2.f});
396+
auto tensor2 = make_tensor_ptr({3.f});
397+
398+
EXPECT_EQ(module.set_input(tensor2, 1), Error::Ok);
399+
400+
const auto result = module.forward(tensor1);
401+
EXPECT_EQ(result.error(), Error::Ok);
402+
403+
const auto data = result->at(0).toTensor().const_data_ptr<float>();
404+
EXPECT_NEAR(data[0], 5, 1e-5);
405+
}
406+
407+
TEST_F(ModuleTest, TestPartiallySetInputs) {
408+
Module module(model_path_);
409+
410+
auto tensor = make_tensor_ptr({1.f});
411+
412+
EXPECT_EQ(module.set_input(tensor, 0), Error::Ok);
413+
414+
const auto result = module.forward();
415+
EXPECT_NE(result.error(), Error::Ok);
416+
}
417+
418+
TEST_F(ModuleTest, TestUnsetInputs) {
419+
Module module(model_path_);
420+
421+
const auto result = module.forward();
422+
EXPECT_NE(result.error(), Error::Ok);
423+
}

0 commit comments

Comments
 (0)