support qnn runner multi iter run #9071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 1 commit on Mar 11, 2025
examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp (18 changes: 10 additions & 8 deletions)
@@ -25,7 +25,6 @@ DEFINE_string(
     model_path,
     "kv_llama_qnn.pte",
     "Model serialized in flatbuffer format.");
-
 DEFINE_string(
     output_path,
     "outputs.txt",
@@ -48,7 +47,6 @@ DEFINE_int32(
     seq_len,
     128,
     "Total number of tokens to generate (prompt + output).");
-
 DEFINE_int32(
     eval_mode,
     1,
@@ -59,6 +57,7 @@ DEFINE_string(
     kv_updater,
     "How to update kv cache. Choose between SmartMask and ShiftPointer",
     "SmartMask");
+DEFINE_int32(num_iters, 1, "total num of iterations to run.");
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -72,7 +71,8 @@ int main(int argc, char** argv) {
       FLAGS_logits_offset,
      FLAGS_temperature,
       FLAGS_eval_mode,
-      FLAGS_kv_updater);
+      FLAGS_kv_updater,
+      FLAGS_num_iters);
   std::vector<char> buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
   std::ofstream fout(FLAGS_output_path.c_str());
@@ -82,11 +82,13 @@ int main(int argc, char** argv) {
     }
   };
   // generate tokens & store inference output
-  runner.generate(
-      FLAGS_seq_len,
-      FLAGS_prompt.c_str(),
-      FLAGS_system_prompt.c_str(),
-      callback);
+  for (int i = 0; i < FLAGS_num_iters; i++) {
+    runner.generate(
+        FLAGS_seq_len,
+        FLAGS_prompt.c_str(),
+        FLAGS_system_prompt.c_str(),
+        callback);
+  }
   fout.write(buf.data(), buf.size());
   fout.close();
   return 0;
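
The runner change above is the heart of the PR: the existing generate() call is wrapped in a loop driven by the new --num_iters flag, reusing the same prompt and the same accumulating callback on every pass. Below is a minimal standalone sketch of that pattern; FakeRunner and its fake token output are hypothetical stand-ins for example::Runner and real decoding.

#include <fstream>
#include <functional>
#include <string>
#include <vector>

// Stand-in for example::Runner; generate() feeds token text to a callback.
struct FakeRunner {
  void generate(int seq_len, const std::function<void(const std::string&)>& cb) {
    for (int t = 0; t < seq_len; ++t) cb("tok ");  // pretend decoding
  }
};

int main() {
  const int num_iters = 2;  // mirrors --num_iters
  const int seq_len = 4;    // mirrors --seq_len
  std::vector<char> buf;
  // The real runner reserves ~5 bytes per token for a single pass; the
  // vector simply grows past that hint when more iterations are run.
  buf.reserve(5 * seq_len);
  auto callback = [&](const std::string& piece) {
    for (char c : piece) buf.push_back(c);
  };
  FakeRunner runner;
  for (int i = 0; i < num_iters; ++i) {
    runner.generate(seq_len, callback);  // every pass appends to one buffer
  }
  std::ofstream fout("outputs.txt");
  fout.write(buf.data(), buf.size());  // all iterations land in one file
  return 0;
}

One consequence worth noting: every iteration appends to the same buffer and the file is written once at the end, so outputs.txt contains the concatenated output of all passes.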
examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp (19 changes: 19 additions & 0 deletions)
@@ -168,6 +168,14 @@ void ShiftPointerIoMgr::init_io() {
   }
 }
 
+void ShiftPointerIoMgr::reset_io() {
+  IO* ptr = static_cast<IO*>(data_ptr_.get());
+  std::fill(
+      ptr->prefill_attention_mask.begin(),
+      ptr->prefill_attention_mask.end(),
+      0);
+  std::fill(ptr->kv_attention_mask.begin(), ptr->kv_attention_mask.end(), 0);
+}
 void ShiftPointerIoMgr::prepare_kv_io(
     const std::vector<Result<MethodMeta>>& methods_meta) {
   for (int i = 0; i < modules_.size(); ++i) {

Review thread on the reset_io() addition:

haowhsu-quic (Collaborator) commented on Mar 11, 2025:

    I don't think we actually need to modify the prepare_xx_io interface. The following snippet may be enough:

        std::fill(ptr->prefill_attention_mask.begin(), ptr->prefill_attention_mask.end(), 0);
        std::fill(ptr->kv_attention_mask.begin(), ptr->kv_attention_mask.end(), 0);

    The subsequent prepare_xx_io calls could then be omitted; the attention mask will be set correctly when the runner invokes fill_xx_toks. Ditto for smart-mask, I think. If you find it works for both versions, please map them both to a single implementation. Thank you.

Author (Contributor) replied:

    Never mind; I tried it and it works. I will update the diff with the new logic.

Author (Contributor) replied:

    Updated the diff. Since SmartMask and ShiftPointer use different data structures for the prefill and KV attention masks, I did not unify reset_io into a single implementation.

@@ -885,6 +893,17 @@ void SmartMaskIoMgr::init_io() {
   ptr->init_io_ptrs(shared_ptr, io_bytes_map);
 }
 
+void SmartMaskIoMgr::reset_io() {
+  IO* ptr = static_cast<IO*>(data_ptr_.get());
+  int32_t prefill_attn_size = prefill_ar_len_ * context_len_;
+  int32_t kv_attn_size = kv_ar_len_ * context_len_;
+  std::fill(
+      ptr->prefill_attention_mask,
+      ptr->prefill_attention_mask + prefill_attn_size,
+      0);
+  std::fill(ptr->kv_attention_mask, ptr->kv_attention_mask + kv_attn_size, 0);
+}
+
 void SmartMaskIoMgr::prepare_kv_io(
     const std::vector<Result<MethodMeta>>& methods_meta) {
   for (int i = 0; i < modules_.size(); ++i) {
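
The two reset_io() bodies above do the same job but cannot share code directly, as the author notes: ShiftPointerIoMgr's masks expose begin()/end() iterators, while SmartMaskIoMgr's masks are raw pointers into a shared buffer whose extents come from prefill_ar_len_ * context_len_ and kv_ar_len_ * context_len_. A self-contained sketch of the two fill styles follows; the element type and sizes are illustrative assumptions, not taken from the headers.

#include <algorithm>
#include <cstdint>
#include <vector>

int main() {
  // ShiftPointer-style: the mask is a container, so fill by iterator range.
  std::vector<uint16_t> prefill_attention_mask(32, 1);
  std::fill(prefill_attention_mask.begin(), prefill_attention_mask.end(), 0);

  // SmartMask-style: the mask is a raw pointer plus an element count
  // (in the real code, kv_ar_len_ * context_len_).
  std::vector<uint16_t> shared_buffer(32, 1);
  uint16_t* kv_attention_mask = shared_buffer.data();
  const int32_t kv_attn_size = static_cast<int32_t>(shared_buffer.size());
  std::fill(kv_attention_mask, kv_attention_mask + kv_attn_size, 0);
  return 0;
}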
examples/qualcomm/oss_scripts/llama/runner/io_manager.h (3 changes: 3 additions & 0 deletions)
@@ -33,6 +33,7 @@ class IoMgrBase {
       std::vector<std::shared_ptr<executorch::extension::Module>>& modules);
   virtual ~IoMgrBase();
   virtual void init_io() = 0;
+  virtual void reset_io() = 0;
   virtual void prepare_prefill_io(
       const std::vector<
           executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -97,6 +98,7 @@ class ShiftPointerIoMgr : public IoMgrBase {
       const bool use_int64_token);
 
   void init_io() override;
+  void reset_io() override;
   void prepare_prefill_io(
       const std::vector<
           executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -199,6 +201,7 @@ class SmartMaskIoMgr : public IoMgrBase {
       const bool use_int64_token);
 
   void init_io() override;
+  void reset_io() override;
   void prepare_prefill_io(
       const std::vector<
           executorch::runtime::Result<executorch::runtime::MethodMeta>>&
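
The header change follows the existing pattern: reset_io() is declared pure virtual on IoMgrBase and overridden by both concrete managers, so the runner can reset either backend through its base-class pointer. A stripped-down sketch of that dispatch (the class bodies here are placeholders, not the real headers):

#include <iostream>
#include <memory>

struct IoMgrBase {
  virtual ~IoMgrBase() = default;
  virtual void reset_io() = 0;  // new pure-virtual hook added by this PR
};

struct ShiftPointerIoMgr : IoMgrBase {
  void reset_io() override { std::cout << "reset shift-pointer masks\n"; }
};

struct SmartMaskIoMgr : IoMgrBase {
  void reset_io() override { std::cout << "reset smart-mask masks\n"; }
};

int main() {
  // The runner holds io_mgr_ as a base-class pointer and calls reset_io()
  // at the end of generate(), whichever updater was selected.
  std::unique_ptr<IoMgrBase> io_mgr = std::make_unique<SmartMaskIoMgr>();
  io_mgr->reset_io();
  return 0;
}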
examples/qualcomm/oss_scripts/llama/runner/runner.cpp (11 changes: 7 additions & 4 deletions)
@@ -48,7 +48,8 @@ Runner::Runner(
     const int32_t logits_offset,
     const float temperature,
     const int eval_mode,
-    const std::string& kv_updater)
+    const std::string& kv_updater,
+    const int num_iters)
     : n_bos_(1),
       n_eos_(1),
       tokenizer_path_(tokenizer_path),
@@ -57,7 +58,8 @@
       logits_offset_(logits_offset),
       temperature_(temperature),
       eval_mode_(static_cast<EvalMode>(eval_mode)),
-      kv_updater_(kv_updater) {
+      kv_updater_(kv_updater),
+      num_iters_(num_iters) {
   for (size_t i = 0; i < models_path.size(); ++i) {
     modules_.push_back(std::make_shared<Module>(
         models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors));
@@ -280,7 +282,7 @@ Error Runner::generate(
   std::unordered_map<std::string, std::vector<std::vector<Tensor>>>
       input_tensors, output_tensors;
   std::unordered_map<std::string, std::vector<std::vector<EValue>>> inputs;
-  if (!is_loaded()) {
+  if (!is_loaded() || (num_iters_ > 1)) {
     stats_.model_load_start_ms = time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
     for (auto method_name : method_names_) {
@@ -445,7 +447,8 @@ Error Runner::generate(
   if (stats_callback) {
     stats_callback(stats_);
   }
-
+  io_mgr_->reset_io();
+  prompt_.clear();
   return Error::Ok;
 }
 
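
Two behavioral points in this diff are easy to miss: when num_iters_ > 1, generate() re-enters the load/setup branch on every call (judging by the diff, this re-prepares the per-method IO tensors and re-records model-load stats each pass), and every call now ends by zeroing the attention masks and clearing the accumulated prompt. A toy model of that control flow; ToyRunner and its members are hypothetical, and only the condition and the two reset calls mirror the diff:

#include <iostream>
#include <string>

struct ToyRunner {
  int num_iters_ = 1;
  bool loaded_ = false;
  std::string prompt_;

  bool is_loaded() const { return loaded_; }
  void load() { loaded_ = true; }  // stand-in for module load and IO setup
  void reset_io() {}               // stand-in for the mask zeroing above

  void generate(const std::string& prompt) {
    // Mirrors `if (!is_loaded() || (num_iters_ > 1))`: with more than one
    // iteration the setup path runs on every call.
    if (!is_loaded() || num_iters_ > 1) {
      load();
    }
    prompt_ += prompt;
    std::cout << "generated from: " << prompt_ << "\n";
    reset_io();       // zero attention masks for the next pass
    prompt_.clear();  // drop prompt state accumulated during this pass
  }
};

int main() {
  ToyRunner runner;
  runner.num_iters_ = 2;
  for (int i = 0; i < runner.num_iters_; ++i) {
    runner.generate("hello");
  }
  return 0;
}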
examples/qualcomm/oss_scripts/llama/runner/runner.h (4 changes: 3 additions & 1 deletion)
@@ -34,7 +34,8 @@ class Runner {
       const int32_t logits_offset,
       const float temperature,
       const int eval_mode,
-      const std::string& kv_updater);
+      const std::string& kv_updater,
+      const int num_iters);
 
   struct Stats {
     // Scaling factor for timestamps - in this case, we use ms.
@@ -117,6 +118,7 @@
   std::vector<std::string> method_names_;
   LlamaVersion llama_version_;
   std::string kv_updater_;
+  int num_iters_;
 };
 
 } // namespace example
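
Putting it together, the new option plugs into the standard gflags flow the runner already uses. A trimmed-down, hypothetical driver showing how --num_iters reaches the loop; only the flag definition is taken verbatim from the diff, and every other flag of the real runner is omitted:

#include <gflags/gflags.h>
#include <iostream>

// Taken from the diff above.
DEFINE_int32(num_iters, 1, "total num of iterations to run.");

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  for (int i = 0; i < FLAGS_num_iters; ++i) {
    // The real binary calls runner.generate(...) here.
    std::cout << "iteration " << i << "\n";
  }
  return 0;
}

An invocation would then look something like ./qnn_llama_runner --model_path kv_llama_qnn.pte --num_iters 3, with the binary name assumed from the example directory.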