@@ -167,7 +167,101 @@ void ShiftPointerIoMgr::init_io() {
      break;
  }
}
+void ShiftPointerIoMgr::reset_io(
+    const std::vector<Result<MethodMeta>>& prefill_methods_meta,
+    const std::vector<Result<MethodMeta>>& kv_methods_meta) {
+  IO* ptr = static_cast<IO*>(data_ptr_.get());
+  std::memset(ptr, 0, sizeof(IO));
+  int32_t k_in_size = (head_dim_ + 1) * kv_cache_len_;
+  int32_t max_ar_len = std::max(kv_ar_len_, prefill_ar_len_);
+
+  int32_t v_cache_size = (num_heads_ + 1) * context_len_ * head_dim_;
+  int32_t k_cache_out_size = num_heads_ * max_ar_len * head_dim_;
+
+  ptr->k_cache_out.clear();
+  ptr->v_cache.clear();
+  // Optionally, reserve space again if you plan to refill them
+  ptr->k_cache_out.reserve(num_layers_);
+  ptr->v_cache.reserve(num_layers_);
+  // Refill the vectors if needed
+  for (int layer = 0; layer < num_layers_; layer++) {
+    ptr->k_cache_out.emplace_back(std::vector<uint8_t>(k_cache_out_size));
+    ptr->v_cache.emplace_back(std::vector<uint8_t>(v_cache_size));
+  }
+
+  auto reset_kv = [&]() {
+    ptr->kv_logits.clear();
+    ptr->kv_logits.resize(kv_ar_len_ * vocab_size_);
+
+    ptr->kv_attention_mask.clear();
+    ptr->kv_attention_mask.resize((kv_ar_len_ * context_len_), 0);
+
+    ptr->k_cache.clear();
+    ptr->k_cache.reserve(num_layers_);
+    for (int layer = 0; layer < num_layers_; layer++) {
+      ptr->k_cache.emplace_back();
+      ptr->k_cache[layer].reserve(num_heads_);
+      for (int head = 0; head < num_heads_; head++) {
+        ptr->k_cache[layer].emplace_back(std::vector<uint8_t>(k_in_size));
+      }
+    }
+  };
+
+  auto reset_prefill = [&]() {
+    ptr->prefill_input_toks.clear();
+    ptr->prefill_input_toks.resize(prefill_ar_len_, 0);
+
+    ptr->prefill_input_pos.clear();
+    ptr->prefill_input_pos.resize(prefill_ar_len_, 0);
+
+    ptr->prefill_attention_mask.clear();
+    ptr->prefill_attention_mask.resize((prefill_ar_len_ * context_len_), 0);

+    ptr->prefill_logits.clear();
+    ptr->prefill_logits.resize(prefill_ar_len_ * vocab_size_);
+  };
+  switch (eval_mode_) {
+    case EvalMode::kKVCached:
+      reset_kv();
+      break;
+    case EvalMode::kHybrid:
+      reset_prefill();
+      reset_kv();
+      break;
+    default:
+      break;
+  }
+
+  input_tensors_[kv_forward_name_].clear();
+  input_tensors_[kv_forward_name_].resize(modules_.size());
+  output_tensors_[kv_forward_name_].clear();
+  output_tensors_[kv_forward_name_].resize(modules_.size());
+  k_cache_in_[kv_forward_name_].clear();
+  v_cache_in_[kv_forward_name_].clear();
+  k_cache_out_[kv_forward_name_].clear();
+  v_cache_out_[kv_forward_name_].clear();
+  input_tensors_[prefill_forward_name_].clear();
+  input_tensors_[prefill_forward_name_].resize(modules_.size());
+  output_tensors_[prefill_forward_name_].clear();
+  output_tensors_[prefill_forward_name_].resize(modules_.size());
+  k_cache_in_[prefill_forward_name_].clear();
+  v_cache_in_[prefill_forward_name_].clear();
+  k_cache_out_[prefill_forward_name_].clear();
+  v_cache_out_[prefill_forward_name_].clear();
+
+  switch (eval_mode_) {
+    case EvalMode::kKVCached:
+      prepare_kv_io(kv_methods_meta);
+      break;
+    case EvalMode::kHybrid:
+      prepare_prefill_io(prefill_methods_meta);
+      prepare_kv_io(kv_methods_meta);
+      break;
+    default:
+      ET_CHECK_MSG(false, "unsupported mode");
+      break;
+  }
+}
void ShiftPointerIoMgr::prepare_kv_io(
    const std::vector<Result<MethodMeta>>& methods_meta) {
  for (int i = 0; i < modules_.size(); ++i) {
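For readers less familiar with the idiom, the buffer handling inside the new `ShiftPointerIoMgr::reset_io()` comes down to two standard-library patterns, illustrated by the self-contained sketch below. The dimensions are made-up placeholders (in the runner they come from `num_layers_`, `num_heads_`, `head_dim_`, `kv_cache_len_`, and the AR lengths); only the reset pattern itself mirrors the diff.

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Illustrative sizes only; the real values come from the model metadata.
  const int num_layers = 2, num_heads = 4, head_dim = 8, kv_cache_len = 16;
  const int k_in_size = (head_dim + 1) * kv_cache_len;

  // Flat buffers (logits, attention masks): clear() followed by resize(n, 0)
  // leaves exactly n zero-initialized elements, which is what
  // reset_kv()/reset_prefill() do for the masks and logits.
  std::vector<uint8_t> attention_mask;
  attention_mask.clear();
  attention_mask.resize(kv_cache_len, 0);

  // Nested caches: clear(), reserve the outer dimension, then emplace one
  // zero-filled buffer per layer and head, mirroring how k_cache is rebuilt.
  std::vector<std::vector<std::vector<uint8_t>>> k_cache;
  k_cache.clear();
  k_cache.reserve(num_layers);
  for (int layer = 0; layer < num_layers; ++layer) {
    k_cache.emplace_back();
    k_cache[layer].reserve(num_heads);
    for (int head = 0; head < num_heads; ++head) {
      k_cache[layer].emplace_back(std::vector<uint8_t>(k_in_size));
    }
  }
  return 0;
}
```

Fully re-sizing rather than only calling `clear()` matters here: every element has to exist and start at zero before `prepare_*_io()` wraps the storage in tensors again.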
@@ -179,7 +273,6 @@ void ShiftPointerIoMgr::prepare_kv_io(

  ET_CHECK_MSG(!(kv_forward_name_.empty()), "kv forward name is empty");
  IO* ptr = static_cast<IO*>(data_ptr_.get());
-
  // [I]: input_tokens
  Result<TensorInfo> kv_input_toks = methods_meta[0]->input_tensor_meta(0);
  kv_input_toks_ = std::make_unique<TensorImpl>(
@@ -406,7 +499,6 @@ void ShiftPointerIoMgr::prepare_prefill_io(
      const_cast<TensorImpl::DimOrderType*>(logits->dim_order().data()));
  output_tensors_[prefill_forward_name_][modules_.size() - 1].push_back(
      prefill_logits_.get());
-
  // [O] kv_cache
  int index = 1;
  // In hybrid mode, we use kv mode cache len for v stride since we want to
@@ -885,6 +977,44 @@ void SmartMaskIoMgr::init_io() {
  ptr->init_io_ptrs(shared_ptr, io_bytes_map);
}

+void SmartMaskIoMgr::reset_io(
+    const std::vector<Result<MethodMeta>>& prefill_methods_meta,
+    const std::vector<Result<MethodMeta>>& kv_methods_meta) {
+  init_io();
+  input_tensors_[kv_forward_name_].clear();
+  input_tensors_[kv_forward_name_].resize(modules_.size());
+  output_tensors_[kv_forward_name_].clear();
+  output_tensors_[kv_forward_name_].resize(modules_.size());
+
+  k_cache_in_[kv_forward_name_].clear();
+  v_cache_in_[kv_forward_name_].clear();
+  k_cache_out_[kv_forward_name_].clear();
+  v_cache_out_[kv_forward_name_].clear();
+
+  input_tensors_[prefill_forward_name_].clear();
+  input_tensors_[prefill_forward_name_].resize(modules_.size());
+  output_tensors_[prefill_forward_name_].clear();
+  output_tensors_[prefill_forward_name_].resize(modules_.size());
+
+  k_cache_in_[prefill_forward_name_].clear();
+  v_cache_in_[prefill_forward_name_].clear();
+  k_cache_out_[prefill_forward_name_].clear();
+  v_cache_out_[prefill_forward_name_].clear();
+
+  switch (eval_mode_) {
+    case EvalMode::kKVCached:
+      prepare_kv_io(kv_methods_meta);
+      break;
+    case EvalMode::kHybrid:
+      prepare_prefill_io(prefill_methods_meta);
+      prepare_kv_io(kv_methods_meta);
+      break;
+    default:
+      ET_CHECK_MSG(false, "unsupported mode");
+      break;
+  }
+}
+
void SmartMaskIoMgr::prepare_kv_io(
    const std::vector<Result<MethodMeta>>& methods_meta) {
  for (int i = 0; i < modules_.size(); ++i) {
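One design note on the split between the two managers: `SmartMaskIoMgr::reset_io()` re-runs `init_io()` so the shared buffer region and its pointer map are rebuilt wholesale, and then only clears the per-method bookkeeping before re-preparing, whereas `ShiftPointerIoMgr::reset_io()` zeroes its `IO` struct and refills each cache vector by hand. What the two overrides share is the per-forward-name tensor bookkeeping reset; the standalone sketch below illustrates just that pattern (`TensorHandle`, the module count, and the forward-name strings are placeholders, not the runner's real types).

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Placeholder for the tensor pointers the real managers track per module.
using TensorHandle = const void*;

int main() {
  const std::size_t num_modules = 4;  // stand-in for modules_.size()
  std::map<std::string, std::vector<std::vector<TensorHandle>>> input_tensors;
  std::map<std::string, std::vector<std::vector<TensorHandle>>> output_tensors;

  // Reset the bookkeeping for both graphs: drop the stale entries, then
  // resize back to one empty slot per module so a later prepare step can
  // push freshly created tensors into place.
  for (const char* name : {"prefill_forward", "kv_forward"}) {
    input_tensors[name].clear();
    input_tensors[name].resize(num_modules);
    output_tensors[name].clear();
    output_tensors[name].resize(num_modules);
  }
  return 0;
}
```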