Skip to content

Commit 14262ab

Browse files
committed
whisper : simplify encoder FA
1 parent 5b8b92c commit 14262ab

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

whisper.cpp

+13-22
Original file line numberDiff line numberDiff line change
@@ -1969,32 +1969,24 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
19691969
0, 2, 1, 3);
19701970

19711971
#ifdef WHISPER_USE_FLASH_ATTN
1972-
struct ggml_tensor * Kpad = ggml_reshape_3d(ctx0, kv_pad.k, n_state_head, n_ctx_pad, n_head);
1972+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
1973+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
19731974

19741975
struct ggml_tensor * K =
1975-
ggml_cpy(ctx0,
1976-
ggml_permute(ctx0,
1977-
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
1978-
0, 2, 1, 3),
1979-
ggml_view_3d(ctx0,
1980-
Kpad,
1981-
n_state_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
1982-
1983-
struct ggml_tensor * Vpad = ggml_reshape_3d(ctx0, kv_pad.v, n_state_head, n_ctx_pad, n_head);
1976+
ggml_view_3d(ctx0, kv_pad.k,
1977+
n_state_head, n_ctx_pad, n_head,
1978+
ggml_element_size(kv_pad.k)*n_state,
1979+
ggml_element_size(kv_pad.k)*n_state_head,
1980+
0);
19841981

19851982
struct ggml_tensor * V =
1986-
ggml_cpy(ctx0,
1987-
ggml_permute(ctx0,
1988-
ggml_reshape_3d(ctx0, Vcur, n_state_head, n_head, n_ctx),
1989-
0, 2, 1, 3),
1990-
ggml_view_3d(ctx0,
1991-
Vpad,
1992-
n_state_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
1993-
1994-
ggml_build_forward_expand(gf, K);
1995-
ggml_build_forward_expand(gf, V);
1983+
ggml_view_3d(ctx0, kv_pad.v,
1984+
n_state_head, n_ctx_pad, n_head,
1985+
ggml_element_size(kv_pad.v)*n_state,
1986+
ggml_element_size(kv_pad.v)*n_state_head,
1987+
0);
19961988

1997-
cur = ggml_flash_attn_ext(ctx0, Q, Kpad, Vpad, nullptr, KQscale, 0.0f);
1989+
cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
19981990

19991991
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
20001992
#else
@@ -2519,7 +2511,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
25192511
0, 2, 1, 3);
25202512

25212513
#ifdef WHISPER_USE_FLASH_ATTN
2522-
// Kcross is already scaled
25232514
struct ggml_tensor * Kcross =
25242515
ggml_view_3d(ctx0, wstate.kv_cross.k,
25252516
n_state_head, n_audio_ctx_pad, n_head,

0 commit comments

Comments
 (0)