@@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
1466
1466
return 0;
1467
1467
}
1468
1468
1469
- static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
1470
- if (c0 < 0) c0 = 0;
1471
- if (c1 < 0) c1 = cache.size;
1472
-
1473
- for (int32_t i = c0; i < c1; ++i) {
1469
+ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
1470
+ for (int32_t i = 0; i < cache.size; ++i) {
1474
1471
cache.cells[i].pos = -1;
1475
1472
cache.cells[i].seq_id.clear();
1476
1473
}
1477
-
1478
- // Searching for a free slot can start here since we know it will be empty.
1479
- cache.head = uint32_t(c0);
1474
+ cache.head = 0;
1480
1475
}
1481
1476
1482
1477
static void llama_kv_cache_seq_rm(
@@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
1490
1485
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
1491
1486
1492
1487
for (uint32_t i = 0; i < cache.size; ++i) {
1493
- if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1494
- cache.cells[i].seq_id.erase(seq_id);
1488
+ if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1489
+ if (seq_id < 0) {
1490
+ cache.cells[i].seq_id.clear();
1491
+ } else if (cache.cells[i].has_seq_id(seq_id)) {
1492
+ cache.cells[i].seq_id.erase(seq_id);
1493
+ } else {
1494
+ continue;
1495
+ }
1495
1496
if (cache.cells[i].seq_id.empty()) {
1496
1497
cache.cells[i].pos = -1;
1497
1498
if (new_head == cache.size) new_head = i;
@@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
9207
9208
return ctx->kv_self.head;
9208
9209
}
9209
9210
9210
- void llama_kv_cache_tokens_rm (struct llama_context * ctx, int32_t c0, int32_t c1 ) {
9211
- llama_kv_cache_tokens_rm (ctx->kv_self, c0, c1 );
9211
+ void llama_kv_cache_clear (struct llama_context * ctx) {
9212
+ llama_kv_cache_clear (ctx->kv_self);
9212
9213
}
9213
9214
9214
9215
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -9654,7 +9655,7 @@ int llama_eval(
9654
9655
llama_token * tokens,
9655
9656
int32_t n_tokens,
9656
9657
int n_past) {
9657
- llama_kv_cache_tokens_rm (ctx->kv_self, n_past, -1);
9658
+ llama_kv_cache_seq_rm (ctx->kv_self, -1 , n_past, -1);
9658
9659
9659
9660
const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
9660
9661
if (ret < 0) {
@@ -9669,7 +9670,7 @@ int llama_eval_embd(
9669
9670
float * embd,
9670
9671
int32_t n_tokens,
9671
9672
int n_past) {
9672
- llama_kv_cache_tokens_rm (ctx->kv_self, n_past, -1);
9673
+ llama_kv_cache_seq_rm (ctx->kv_self, -1 , n_past, -1);
9673
9674
9674
9675
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
9675
9676
0 commit comments