Skip to content

Commit 09f1baf

Browse files
billmguo authored and facebook-github-bot committed
support qnn runner multi iter run
Summary: Support QNN runner multi-iteration run. Differential Revision: D70842764
1 parent 56c94c2 commit 09f1baf

File tree

5 files changed

+172
-14
lines changed

5 files changed

+172
-14
lines changed

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ DEFINE_string(
2525
model_path,
2626
"kv_llama_qnn.pte",
2727
"Model serialized in flatbuffer format.");
28-
2928
DEFINE_string(
3029
output_path,
3130
"outputs.txt",
@@ -48,7 +47,6 @@ DEFINE_int32(
4847
seq_len,
4948
128,
5049
"Total number of tokens to generate (prompt + output).");
51-
5250
DEFINE_int32(
5351
eval_mode,
5452
1,
@@ -59,6 +57,7 @@ DEFINE_string(
5957
kv_updater,
6058
"How to update kv cache. Choose between SmartMask and ShiftPointer",
6159
"SmartMask");
60+
DEFINE_int32(num_iters, 1, "total num of iterations to run.");
6261

6362
int main(int argc, char** argv) {
6463
gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -72,7 +71,8 @@ int main(int argc, char** argv) {
7271
FLAGS_logits_offset,
7372
FLAGS_temperature,
7473
FLAGS_eval_mode,
75-
FLAGS_kv_updater);
74+
FLAGS_kv_updater,
75+
FLAGS_num_iters);
7676
std::vector<char> buf;
7777
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
7878
std::ofstream fout(FLAGS_output_path.c_str());
@@ -82,11 +82,13 @@ int main(int argc, char** argv) {
8282
}
8383
};
8484
// generate tokens & store inference output
85-
runner.generate(
86-
FLAGS_seq_len,
87-
FLAGS_prompt.c_str(),
88-
FLAGS_system_prompt.c_str(),
89-
callback);
85+
for (int i = 0; i < FLAGS_num_iters; i++) {
86+
runner.generate(
87+
FLAGS_seq_len,
88+
FLAGS_prompt.c_str(),
89+
FLAGS_system_prompt.c_str(),
90+
callback);
91+
}
9092
fout.write(buf.data(), buf.size());
9193
fout.close();
9294
return 0;

examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,101 @@ void ShiftPointerIoMgr::init_io() {
167167
break;
168168
}
169169
}
170+
void ShiftPointerIoMgr::reset_io(
171+
const std::vector<Result<MethodMeta>>& prefill_methods_meta,
172+
const std::vector<Result<MethodMeta>>& kv_methods_meta) {
173+
IO* ptr = static_cast<IO*>(data_ptr_.get());
174+
std::memset(ptr, 0, sizeof(IO));
175+
int32_t k_in_size = (head_dim_ + 1) * kv_cache_len_;
176+
int32_t max_ar_len = std::max(kv_ar_len_, prefill_ar_len_);
177+
178+
int32_t v_cache_size = (num_heads_ + 1) * context_len_ * head_dim_;
179+
int32_t k_cache_out_size = num_heads_ * max_ar_len * head_dim_;
180+
181+
ptr->k_cache_out.clear();
182+
ptr->v_cache.clear();
183+
// Optionally, reserve space again if you plan to refill them
184+
ptr->k_cache_out.reserve(num_layers_);
185+
ptr->v_cache.reserve(num_layers_);
186+
// Refill the vectors if needed
187+
for (int layer = 0; layer < num_layers_; layer++) {
188+
ptr->k_cache_out.emplace_back(std::vector<uint8_t>(k_cache_out_size));
189+
ptr->v_cache.emplace_back(std::vector<uint8_t>(v_cache_size));
190+
}
191+
192+
auto reset_kv = [&]() {
193+
ptr->kv_logits.clear();
194+
ptr->kv_logits.resize(kv_ar_len_ * vocab_size_);
195+
196+
ptr->kv_attention_mask.clear();
197+
ptr->kv_attention_mask.resize((kv_ar_len_ * context_len_), 0);
198+
199+
ptr->k_cache.clear();
200+
ptr->k_cache.reserve(num_layers_);
201+
for (int layer = 0; layer < num_layers_; layer++) {
202+
ptr->k_cache.emplace_back();
203+
ptr->k_cache[layer].reserve(num_heads_);
204+
for (int head = 0; head < num_heads_; head++) {
205+
ptr->k_cache[layer].emplace_back(std::vector<uint8_t>(k_in_size));
206+
}
207+
}
208+
};
209+
210+
auto reset_prefill = [&]() {
211+
ptr->prefill_input_toks.clear();
212+
ptr->prefill_input_toks.resize(prefill_ar_len_, 0);
213+
214+
ptr->prefill_input_pos.clear();
215+
ptr->prefill_input_pos.resize(prefill_ar_len_, 0);
216+
217+
ptr->prefill_attention_mask.clear();
218+
ptr->prefill_attention_mask.resize((prefill_ar_len_ * context_len_), 0);
170219

220+
ptr->prefill_logits.clear();
221+
ptr->prefill_logits.resize(prefill_ar_len_ * vocab_size_);
222+
};
223+
switch (eval_mode_) {
224+
case EvalMode::kKVCached:
225+
reset_kv();
226+
break;
227+
case EvalMode::kHybrid:
228+
reset_prefill();
229+
reset_kv();
230+
break;
231+
default:
232+
break;
233+
}
234+
235+
input_tensors_[kv_forward_name_].clear();
236+
input_tensors_[kv_forward_name_].resize(modules_.size());
237+
output_tensors_[kv_forward_name_].clear();
238+
output_tensors_[kv_forward_name_].resize(modules_.size());
239+
k_cache_in_[kv_forward_name_].clear();
240+
v_cache_in_[kv_forward_name_].clear();
241+
k_cache_out_[kv_forward_name_].clear();
242+
v_cache_out_[kv_forward_name_].clear();
243+
input_tensors_[prefill_forward_name_].clear();
244+
input_tensors_[prefill_forward_name_].resize(modules_.size());
245+
output_tensors_[prefill_forward_name_].clear();
246+
output_tensors_[prefill_forward_name_].resize(modules_.size());
247+
k_cache_in_[prefill_forward_name_].clear();
248+
v_cache_in_[prefill_forward_name_].clear();
249+
k_cache_out_[prefill_forward_name_].clear();
250+
v_cache_out_[prefill_forward_name_].clear();
251+
252+
switch (eval_mode_) {
253+
case EvalMode::kKVCached:
254+
prepare_kv_io(kv_methods_meta);
255+
break;
256+
case EvalMode::kHybrid:
257+
prepare_prefill_io(prefill_methods_meta);
258+
prepare_kv_io(kv_methods_meta);
259+
break;
260+
default:
261+
ET_CHECK_MSG(false, "unsupported mode");
262+
break;
263+
}
264+
}
171265
void ShiftPointerIoMgr::prepare_kv_io(
172266
const std::vector<Result<MethodMeta>>& methods_meta) {
173267
for (int i = 0; i < modules_.size(); ++i) {
@@ -179,7 +273,6 @@ void ShiftPointerIoMgr::prepare_kv_io(
179273

180274
ET_CHECK_MSG(!(kv_forward_name_.empty()), "kv forward name is empty");
181275
IO* ptr = static_cast<IO*>(data_ptr_.get());
182-
183276
// [I]: input_tokens
184277
Result<TensorInfo> kv_input_toks = methods_meta[0]->input_tensor_meta(0);
185278
kv_input_toks_ = std::make_unique<TensorImpl>(
@@ -406,7 +499,6 @@ void ShiftPointerIoMgr::prepare_prefill_io(
406499
const_cast<TensorImpl::DimOrderType*>(logits->dim_order().data()));
407500
output_tensors_[prefill_forward_name_][modules_.size() - 1].push_back(
408501
prefill_logits_.get());
409-
410502
// [O] kv_cache
411503
int index = 1;
412504
// In hybrid mode, we use kv mode cache len for v stride since we want to
@@ -885,6 +977,44 @@ void SmartMaskIoMgr::init_io() {
885977
ptr->init_io_ptrs(shared_ptr, io_bytes_map);
886978
}
887979

980+
void SmartMaskIoMgr::reset_io(
981+
const std::vector<Result<MethodMeta>>& prefill_methods_meta,
982+
const std::vector<Result<MethodMeta>>& kv_methods_meta) {
983+
init_io();
984+
input_tensors_[kv_forward_name_].clear();
985+
input_tensors_[kv_forward_name_].resize(modules_.size());
986+
output_tensors_[kv_forward_name_].clear();
987+
output_tensors_[kv_forward_name_].resize(modules_.size());
988+
989+
k_cache_in_[kv_forward_name_].clear();
990+
v_cache_in_[kv_forward_name_].clear();
991+
k_cache_out_[kv_forward_name_].clear();
992+
v_cache_out_[kv_forward_name_].clear();
993+
994+
input_tensors_[prefill_forward_name_].clear();
995+
input_tensors_[prefill_forward_name_].resize(modules_.size());
996+
output_tensors_[prefill_forward_name_].clear();
997+
output_tensors_[prefill_forward_name_].resize(modules_.size());
998+
999+
k_cache_in_[prefill_forward_name_].clear();
1000+
v_cache_in_[prefill_forward_name_].clear();
1001+
k_cache_out_[prefill_forward_name_].clear();
1002+
v_cache_out_[prefill_forward_name_].clear();
1003+
1004+
switch (eval_mode_) {
1005+
case EvalMode::kKVCached:
1006+
prepare_kv_io(prefill_methods_meta);
1007+
break;
1008+
case EvalMode::kHybrid:
1009+
prepare_prefill_io(prefill_methods_meta);
1010+
prepare_kv_io(kv_methods_meta);
1011+
break;
1012+
default:
1013+
ET_CHECK_MSG(false, "unsupported mode");
1014+
break;
1015+
}
1016+
}
1017+
8881018
void SmartMaskIoMgr::prepare_kv_io(
8891019
const std::vector<Result<MethodMeta>>& methods_meta) {
8901020
for (int i = 0; i < modules_.size(); ++i) {

examples/qualcomm/oss_scripts/llama/runner/io_manager.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ class IoMgrBase {
3333
std::vector<std::shared_ptr<executorch::extension::Module>>& modules);
3434
virtual ~IoMgrBase();
3535
virtual void init_io() = 0;
36+
virtual void reset_io(
37+
const std::vector<executorch::runtime::Result<
38+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
39+
const std::vector<
40+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
41+
kv_methods_meta) = 0;
3642
virtual void prepare_prefill_io(
3743
const std::vector<
3844
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -97,6 +103,12 @@ class ShiftPointerIoMgr : public IoMgrBase {
97103
const bool use_int64_token);
98104

99105
void init_io() override;
106+
void reset_io(
107+
const std::vector<executorch::runtime::Result<
108+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
109+
const std::vector<
110+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
111+
kv_methods_meta) override;
100112
void prepare_prefill_io(
101113
const std::vector<
102114
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -199,6 +211,12 @@ class SmartMaskIoMgr : public IoMgrBase {
199211
const bool use_int64_token);
200212

201213
void init_io() override;
214+
void reset_io(
215+
const std::vector<executorch::runtime::Result<
216+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
217+
const std::vector<
218+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
219+
kv_methods_meta) override;
202220
void prepare_prefill_io(
203221
const std::vector<
204222
executorch::runtime::Result<executorch::runtime::MethodMeta>>&

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ Runner::Runner(
4848
const int32_t logits_offset,
4949
const float temperature,
5050
const int eval_mode,
51-
const std::string& kv_updater)
51+
const std::string& kv_updater,
52+
const int num_iters)
5253
: n_bos_(1),
5354
n_eos_(1),
5455
tokenizer_path_(tokenizer_path),
@@ -57,7 +58,8 @@ Runner::Runner(
5758
logits_offset_(logits_offset),
5859
temperature_(temperature),
5960
eval_mode_(static_cast<EvalMode>(eval_mode)),
60-
kv_updater_(kv_updater) {
61+
kv_updater_(kv_updater),
62+
num_iters_(num_iters) {
6163
for (size_t i = 0; i < models_path.size(); ++i) {
6264
modules_.push_back(std::make_shared<Module>(
6365
models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors));
@@ -280,7 +282,7 @@ Error Runner::generate(
280282
std::unordered_map<std::string, std::vector<std::vector<Tensor>>>
281283
input_tensors, output_tensors;
282284
std::unordered_map<std::string, std::vector<std::vector<EValue>>> inputs;
283-
if (!is_loaded()) {
285+
if (!is_loaded() || num_iters_ > 1) {
284286
stats_.model_load_start_ms = time_in_ms();
285287
ET_CHECK_OK_OR_RETURN_ERROR(load());
286288
for (auto method_name : method_names_) {
@@ -445,7 +447,11 @@ Error Runner::generate(
445447
if (stats_callback) {
446448
stats_callback(stats_);
447449
}
450+
io_mgr_->reset_io(
451+
get_methods_meta(prefill_forward_name_),
452+
get_methods_meta(kv_forward_name_));
448453

454+
prompt_.clear();
449455
return Error::Ok;
450456
}
451457

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class Runner {
3434
const int32_t logits_offset,
3535
const float temperature,
3636
const int eval_mode,
37-
const std::string& kv_updater);
37+
const std::string& kv_updater,
38+
const int num_iters);
3839

3940
struct Stats {
4041
// Scaling factor for timestamps - in this case, we use ms.
@@ -117,6 +118,7 @@ class Runner {
117118
std::vector<std::string> method_names_;
118119
LlamaVersion llama_version_;
119120
std::string kv_updater_;
121+
int num_iters_;
120122
};
121123

122124
} // namespace example

0 commit comments

Comments
 (0)