Skip to content

Commit d5295b4

Browse files
committed
Removed unnecessary reshapes when retrieving kv from cache
1 parent d8c51b2 · commit d5295b4

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

examples/falcon/main.cpp

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -527,14 +527,13 @@ bool falcon_eval(
527527

528528
struct ggml_tensor * K = ggml_permute(
529529
ctx0,
530-
ggml_reshape_3d(
530+
ggml_view_3d(
531531
ctx0,
532-
ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_head_kv * head_dim,
533-
il * n_ctx *
534-
ggml_element_size(model.memory_k) *
535-
n_head_kv *
536-
head_dim),
537-
head_dim, n_head_kv, n_past + N),
532+
model.memory_k,
533+
head_dim, n_head_kv, n_past + N,
534+
head_dim * sizeof_wtype,
535+
head_dim * n_head_kv * sizeof_wtype,
536+
il * n_ctx * ggml_element_size(model.memory_k) * n_head_kv * head_dim),
538537
0, 2, 1, 3);
539538

540539
// K * Q
@@ -560,14 +559,13 @@ bool falcon_eval(
560559
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
561560
struct ggml_tensor* V = ggml_permute(
562561
ctx0,
563-
ggml_reshape_3d(
562+
ggml_view_3d(
564563
ctx0,
565-
ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_head_kv * head_dim,
566-
il * n_ctx *
567-
ggml_element_size(model.memory_v) *
568-
n_head_kv *
569-
head_dim),
570-
head_dim, n_head_kv, n_past + N),
564+
model.memory_v,
565+
head_dim, n_head_kv, n_past + N,
566+
head_dim * sizeof_wtype,
567+
head_dim * n_head_kv * sizeof_wtype,
568+
il * n_ctx * ggml_element_size(model.memory_v) * n_head_kv * head_dim),
571569
0, 2, 1, 3);
572570

573571
V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));

0 commit comments

Comments
 (0)