Skip to content

Commit 768f5c9

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Simplify setting output. (#5363)
Summary: Pull Request resolved: #5363 . Reviewed By: dbort Differential Revision: D62660981 fbshipit-source-id: a059ad0f67a47b15a8062af5889a4e6af52bf2b3
1 parent 74a56e4 commit 768f5c9

File tree

7 files changed

+72
-32
lines changed

7 files changed

+72
-32
lines changed

examples/qualcomm/oss_scripts/llama2/runner/runner.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ Result<exec_aten::Tensor> Runner::run_model_step(
187187
*kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
188188
"Failed to set output tensor when updating v_cache");
189189
ET_CHECK_MSG(
190-
module_->set_output_data_ptr(*kv_outputs[j], j + 1) == Error::Ok,
190+
module_->set_output(*kv_outputs[j], j + 1) == Error::Ok,
191191
"Failed to set llama output data pointer");
192192
}
193193

@@ -291,7 +291,7 @@ Error Runner::generate(
291291
sizes,
292292
kv_tensors.back()->scalar_type()));
293293
ET_CHECK_MSG(
294-
module_->set_output_data_ptr(kv_outputs.back(), i + 1) == Error::Ok,
294+
module_->set_output(kv_outputs.back(), i + 1) == Error::Ok,
295295
"Failed to set output tensor for kv cache");
296296
}
297297

@@ -323,8 +323,7 @@ Error Runner::generate(
323323
sizes,
324324
kv_tensors.back()->scalar_type()));
325325
ET_CHECK_MSG(
326-
module_->set_output_data_ptr(kv_outputs.back(), output_index) ==
327-
Error::Ok,
326+
module_->set_output(kv_outputs.back(), output_index) == Error::Ok,
328327
"Failed to set output tensor for llama block");
329328
}
330329

@@ -333,7 +332,7 @@ Error Runner::generate(
333332
logits_data_shape,
334333
ScalarType::Float);
335334
ET_CHECK_MSG(
336-
module_->set_output_data_ptr(affine_logits, 0) == Error::Ok,
335+
module_->set_output(affine_logits) == Error::Ok,
337336
"Failed to set output tensor for affine module - logits");
338337

339338
// Start consuming user's prompts and generating new tokens

examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ void KVCachedMemory::update_io(
427427
// k, v are placed interleaved
428428
int index = (cache_stride << 1) + (cache_group << 5) + head;
429429
ET_CHECK_MSG(
430-
modules_[shard]->set_output_data_ptr(
430+
modules_[shard]->set_output(
431431
output_tensors[shard][index], index) == Error::Ok,
432432
"failed to set output tensor for module %d's %d'th output "
433433
"while updating kv_cache output tensors",
@@ -450,8 +450,8 @@ void KVCachedMemory::update_io(
450450
for (int shard = 0; shard < output_tensors.size(); shard++) {
451451
for (int index = 0; index < output_tensors[shard].size(); index++) {
452452
ET_CHECK_MSG(
453-
modules_[shard]->set_output_data_ptr(
454-
output_tensors[shard][index], index) == Error::Ok,
453+
modules_[shard]->set_output(output_tensors[shard][index], index) ==
454+
Error::Ok,
455455
"failed to set output tensor for module %d's %d'th output "
456456
"while updating kv_cache output tensors",
457457
shard,

examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,7 @@ Error Runner::generate(
177177
output_tensors.emplace_back(io_mem_->get_output_tensors(i));
178178
for (size_t j = 0; j < output_tensors[i].size(); ++j) {
179179
ET_CHECK_MSG(
180-
modules_[i]->set_output_data_ptr(output_tensors[i][j], j) ==
181-
Error::Ok,
180+
modules_[i]->set_output(output_tensors[i][j], j) == Error::Ok,
182181
"failed to set output tensor for module %d's %zu'th output",
183182
i,
184183
j);

examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,11 @@ Error Runner::generate(std::string prompt) {
373373
uncond_emb_vec.data(),
374374
{1, 77, 1024},
375375
encoder_method_meta.output_tensor_meta(0)->scalar_type());
376-
modules_[0]->set_output_data_ptr(cond_emb_tensor, 0);
376+
modules_[0]->set_output(cond_emb_tensor);
377377
long encoder_start = util::time_in_ms();
378378
auto cond_res = modules_[0]->forward(cond_tokens_tensor);
379379
stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start);
380-
modules_[0]->set_output_data_ptr(uncond_emb_tensor, 0);
380+
modules_[0]->set_output(uncond_emb_tensor);
381381
encoder_start = util::time_in_ms();
382382
auto uncond_res = modules_[0]->forward(uncond_tokens_tensor);
383383
stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start);
@@ -462,13 +462,13 @@ Error Runner::generate(std::string prompt) {
462462

463463
stats_.unet_aggregate_post_processing_time +=
464464
(util::time_in_ms() - start_post_process);
465-
modules_[1]->set_output_data_ptr(noise_pred_text_tensor, 0);
465+
modules_[1]->set_output(noise_pred_text_tensor);
466466
long start_unet_execution = util::time_in_ms();
467467
auto cond_res = modules_[1]->forward(
468468
{latent_tensor, time_emb_tensors[step_index], cond_emb_tensor});
469469
stats_.unet_aggregate_execution_time +=
470470
(util::time_in_ms() - start_unet_execution);
471-
modules_[1]->set_output_data_ptr(noise_pred_uncond_tensor, 0);
471+
modules_[1]->set_output(noise_pred_uncond_tensor);
472472
start_unet_execution = util::time_in_ms();
473473
auto uncond_res = modules_[1]->forward(
474474
{latent_tensor,
@@ -519,7 +519,7 @@ Error Runner::generate(std::string prompt) {
519519

520520
quant_tensor(latent, vae_input, vae_input_scale_, vae_input_offset_);
521521

522-
modules_[2]->set_output_data_ptr(output_tensor, 0);
522+
modules_[2]->set_output(output_tensor);
523523
long start_vae_execution = util::time_in_ms();
524524
auto vae_res = modules_[2]->forward(vae_input_tensor);
525525
stats_.vae_execution_time = (util::time_in_ms() - start_vae_execution);

extension/module/module.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,18 @@ runtime::Error Module::set_inputs(
218218
return runtime::Error::Ok;
219219
}
220220

221-
runtime::Error Module::set_output_data_ptr(
221+
runtime::Error Module::set_output(
222+
const std::string& method_name,
222223
runtime::EValue output_value,
223-
size_t output_index,
224-
const std::string& method_name) {
224+
size_t output_index) {
225225
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
226-
auto& output_tensor = output_value.toTensor();
227226
auto& method = methods_.at(method_name).method;
227+
ET_CHECK_OR_RETURN_ERROR(
228+
output_value.isTensor(),
229+
InvalidArgument,
230+
"output type: %zu is not tensor",
231+
(size_t)output_value.tag);
232+
const auto& output_tensor = output_value.toTensor();
228233
return method->set_output_data_ptr(
229234
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
230235
}

extension/module/module.h

+36-13
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,42 @@ class Module {
356356
return set_inputs("forward", input_values);
357357
}
358358

359+
/**
360+
* Sets the output tensor for a specific method.
361+
*
362+
* @param[in] method_name The name of the method.
363+
* @param[in] output_value The EValue containing the Tensor to set as the
364+
* method output.
365+
* @param[in] output_index Zero-based index of the output to set.
366+
*
367+
* @returns An Error to indicate success or failure.
368+
*
369+
* @note Only Tensor outputs are currently supported for setting.
370+
*/
371+
ET_NODISCARD
372+
runtime::Error set_output(
373+
const std::string& method_name,
374+
runtime::EValue output_value,
375+
size_t output_index = 0);
376+
377+
/**
378+
* Sets the output tensor for the "forward" method.
379+
*
380+
* @param[in] output_value The EValue containing the Tensor to set as the
381+
* method output.
382+
* @param[in] output_index Zero-based index of the output to set.
383+
*
384+
* @returns An Error to indicate success or failure.
385+
*
386+
* @note Only Tensor outputs are currently supported for setting.
387+
*/
388+
ET_NODISCARD
389+
inline runtime::Error set_output(
390+
runtime::EValue output_value,
391+
size_t output_index = 0) {
392+
return set_output("forward", std::move(output_value), output_index);
393+
}
394+
359395
/**
360396
* Retrieves the EventTracer instance being used by the Module.
361397
* EventTracer is used for tracking and logging events during the execution
@@ -368,19 +404,6 @@ class Module {
368404
return event_tracer_.get();
369405
}
370406

371-
/**
372-
* Set output data pointer for forward method.
373-
*
374-
* @param[in] output_value A Tensor for the output of 'forward' method.
375-
* @param[in] output_index Index of the output in 'forward' method.
376-
*
377-
* @returns An Error to indicate success or failure of the loading process.
378-
*/
379-
runtime::Error set_output_data_ptr(
380-
runtime::EValue output_value,
381-
size_t output_index,
382-
const std::string& method_name = "forward");
383-
384407
private:
385408
struct MethodHolder {
386409
std::vector<std::vector<uint8_t>> planned_buffers;

extension/module/test/module_test.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,17 @@ TEST_F(ModuleTest, TestUnsetInputs) {
421421
const auto result = module.forward();
422422
EXPECT_NE(result.error(), Error::Ok);
423423
}
424+
425+
TEST_F(ModuleTest, TestSetOutputInvalidIndex) {
426+
Module module(model_path_);
427+
428+
auto output_tensor = empty({1});
429+
430+
EXPECT_NE(module.set_output(output_tensor, 1), Error::Ok);
431+
}
432+
433+
TEST_F(ModuleTest, TestSetOutputInvalidType) {
434+
Module module(model_path_);
435+
436+
EXPECT_NE(module.set_output(EValue()), Error::Ok);
437+
}

0 commit comments

Comments
 (0)