Skip to content

Simplify setting output. #5363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions examples/qualcomm/oss_scripts/llama2/runner/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ Result<exec_aten::Tensor> Runner::run_model_step(
*kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
"Failed to set output tensor when updating v_cache");
ET_CHECK_MSG(
module_->set_output_data_ptr(*kv_outputs[j], j + 1) == Error::Ok,
module_->set_output(*kv_outputs[j], j + 1) == Error::Ok,
"Failed to set llama output data pointer");
}

Expand Down Expand Up @@ -291,7 +291,7 @@ Error Runner::generate(
sizes,
kv_tensors.back()->scalar_type()));
ET_CHECK_MSG(
module_->set_output_data_ptr(kv_outputs.back(), i + 1) == Error::Ok,
module_->set_output(kv_outputs.back(), i + 1) == Error::Ok,
"Failed to set output tensor for kv cache");
}

Expand Down Expand Up @@ -323,8 +323,7 @@ Error Runner::generate(
sizes,
kv_tensors.back()->scalar_type()));
ET_CHECK_MSG(
module_->set_output_data_ptr(kv_outputs.back(), output_index) ==
Error::Ok,
module_->set_output(kv_outputs.back(), output_index) == Error::Ok,
"Failed to set output tensor for llama block");
}

Expand All @@ -333,7 +332,7 @@ Error Runner::generate(
logits_data_shape,
ScalarType::Float);
ET_CHECK_MSG(
module_->set_output_data_ptr(affine_logits, 0) == Error::Ok,
module_->set_output(affine_logits) == Error::Ok,
"Failed to set output tensor for affine module - logits");

// Start consuming user's prompts and generating new tokens
Expand Down
6 changes: 3 additions & 3 deletions examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ void KVCachedMemory::update_io(
// k, v are placed interleaved
int index = (cache_stride << 1) + (cache_group << 5) + head;
ET_CHECK_MSG(
modules_[shard]->set_output_data_ptr(
modules_[shard]->set_output(
output_tensors[shard][index], index) == Error::Ok,
"failed to set output tensor for module %d's %d'th output "
"while updating kv_cache output tensors",
Expand All @@ -450,8 +450,8 @@ void KVCachedMemory::update_io(
for (int shard = 0; shard < output_tensors.size(); shard++) {
for (int index = 0; index < output_tensors[shard].size(); index++) {
ET_CHECK_MSG(
modules_[shard]->set_output_data_ptr(
output_tensors[shard][index], index) == Error::Ok,
modules_[shard]->set_output(output_tensors[shard][index], index) ==
Error::Ok,
"failed to set output tensor for module %d's %d'th output "
"while updating kv_cache output tensors",
shard,
Expand Down
3 changes: 1 addition & 2 deletions examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ Error Runner::generate(
output_tensors.emplace_back(io_mem_->get_output_tensors(i));
for (size_t j = 0; j < output_tensors[i].size(); ++j) {
ET_CHECK_MSG(
modules_[i]->set_output_data_ptr(output_tensors[i][j], j) ==
Error::Ok,
modules_[i]->set_output(output_tensors[i][j], j) == Error::Ok,
"failed to set output tensor for module %d's %zu'th output",
i,
j);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,11 @@ Error Runner::generate(std::string prompt) {
uncond_emb_vec.data(),
{1, 77, 1024},
encoder_method_meta.output_tensor_meta(0)->scalar_type());
modules_[0]->set_output_data_ptr(cond_emb_tensor, 0);
modules_[0]->set_output(cond_emb_tensor);
long encoder_start = util::time_in_ms();
auto cond_res = modules_[0]->forward(cond_tokens_tensor);
stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start);
modules_[0]->set_output_data_ptr(uncond_emb_tensor, 0);
modules_[0]->set_output(uncond_emb_tensor);
encoder_start = util::time_in_ms();
auto uncond_res = modules_[0]->forward(uncond_tokens_tensor);
stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start);
Expand Down Expand Up @@ -462,13 +462,13 @@ Error Runner::generate(std::string prompt) {

stats_.unet_aggregate_post_processing_time +=
(util::time_in_ms() - start_post_process);
modules_[1]->set_output_data_ptr(noise_pred_text_tensor, 0);
modules_[1]->set_output(noise_pred_text_tensor);
long start_unet_execution = util::time_in_ms();
auto cond_res = modules_[1]->forward(
{latent_tensor, time_emb_tensors[step_index], cond_emb_tensor});
stats_.unet_aggregate_execution_time +=
(util::time_in_ms() - start_unet_execution);
modules_[1]->set_output_data_ptr(noise_pred_uncond_tensor, 0);
modules_[1]->set_output(noise_pred_uncond_tensor);
start_unet_execution = util::time_in_ms();
auto uncond_res = modules_[1]->forward(
{latent_tensor,
Expand Down Expand Up @@ -519,7 +519,7 @@ Error Runner::generate(std::string prompt) {

quant_tensor(latent, vae_input, vae_input_scale_, vae_input_offset_);

modules_[2]->set_output_data_ptr(output_tensor, 0);
modules_[2]->set_output(output_tensor);
long start_vae_execution = util::time_in_ms();
auto vae_res = modules_[2]->forward(vae_input_tensor);
stats_.vae_execution_time = (util::time_in_ms() - start_vae_execution);
Expand Down
13 changes: 9 additions & 4 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,18 @@ runtime::Error Module::set_inputs(
return runtime::Error::Ok;
}

runtime::Error Module::set_output_data_ptr(
runtime::Error Module::set_output(
const std::string& method_name,
runtime::EValue output_value,
size_t output_index,
const std::string& method_name) {
size_t output_index) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& output_tensor = output_value.toTensor();
auto& method = methods_.at(method_name).method;
ET_CHECK_OR_RETURN_ERROR(
output_value.isTensor(),
InvalidArgument,
"output type: %zu is not tensor",
(size_t)output_value.tag);
const auto& output_tensor = output_value.toTensor();
return method->set_output_data_ptr(
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
}
Expand Down
49 changes: 36 additions & 13 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,42 @@ class Module {
return set_inputs("forward", input_values);
}

/**
* Sets the output tensor for a specific method.
*
* @param[in] method_name The name of the method.
* @param[in] output_value The EValue containing the Tensor to set as the
* method output.
* @param[in] output_index Zero-based index of the output to set.
*
* @returns An Error to indicate success or failure.
*
* @note Only Tensor outputs are currently supported for setting.
*/
ET_NODISCARD
runtime::Error set_output(
const std::string& method_name,
runtime::EValue output_value,
size_t output_index = 0);

/**
* Sets the output tensor for the "forward" method.
*
* @param[in] output_value The EValue containing the Tensor to set as the
* method output.
* @param[in] output_index Zero-based index of the output to set.
*
* @returns An Error to indicate success or failure.
*
* @note Only Tensor outputs are currently supported for setting.
*/
ET_NODISCARD
inline runtime::Error set_output(
runtime::EValue output_value,
size_t output_index = 0) {
return set_output("forward", std::move(output_value), output_index);
}

/**
* Retrieves the EventTracer instance being used by the Module.
* EventTracer is used for tracking and logging events during the execution
Expand All @@ -368,19 +404,6 @@ class Module {
return event_tracer_.get();
}

/**
* Set output data pointer for forward method.
*
* @param[in] output_value A Tensor for the output of 'forward' method.
* @param[in] output_index Index of the output in 'forward' method.
*
* @returns An Error to indicate success or failure of the loading process.
*/
runtime::Error set_output_data_ptr(
runtime::EValue output_value,
size_t output_index,
const std::string& method_name = "forward");

private:
struct MethodHolder {
std::vector<std::vector<uint8_t>> planned_buffers;
Expand Down
14 changes: 14 additions & 0 deletions extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,17 @@ TEST_F(ModuleTest, TestUnsetInputs) {
const auto result = module.forward();
EXPECT_NE(result.error(), Error::Ok);
}

TEST_F(ModuleTest, TestSetOutputInvalidIndex) {
Module module(model_path_);

auto output_tensor = empty({1});

EXPECT_NE(module.set_output(output_tensor, 1), Error::Ok);
}

TEST_F(ModuleTest, TestSetOutputInvalidType) {
Module module(model_path_);

EXPECT_NE(module.set_output(EValue()), Error::Ok);
}
Loading