@@ -1969,32 +1969,24 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1969
1969
0 , 2 , 1 , 3 );
1970
1970
1971
1971
#ifdef WHISPER_USE_FLASH_ATTN
1972
- struct ggml_tensor * Kpad = ggml_reshape_3d (ctx0, kv_pad.k , n_state_head, n_ctx_pad, n_head);
1972
+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, Kcur, ggml_view_1d (ctx0, kv_pad.k , n_ctx*n_state, 0 )));
1973
+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, Vcur, ggml_view_1d (ctx0, kv_pad.v , n_ctx*n_state, 0 )));
1973
1974
1974
1975
struct ggml_tensor * K =
1975
- ggml_cpy (ctx0,
1976
- ggml_permute (ctx0,
1977
- ggml_reshape_3d (ctx0, Kcur, n_state_head, n_head, n_ctx),
1978
- 0 , 2 , 1 , 3 ),
1979
- ggml_view_3d (ctx0,
1980
- Kpad,
1981
- n_state_head, n_ctx, n_head, Kpad->nb [1 ], Kpad->nb [2 ], 0 ));
1982
-
1983
- struct ggml_tensor * Vpad = ggml_reshape_3d (ctx0, kv_pad.v , n_state_head, n_ctx_pad, n_head);
1976
+ ggml_view_3d (ctx0, kv_pad.k ,
1977
+ n_state_head, n_ctx_pad, n_head,
1978
+ ggml_element_size (kv_pad.k )*n_state,
1979
+ ggml_element_size (kv_pad.k )*n_state_head,
1980
+ 0 );
1984
1981
1985
1982
struct ggml_tensor * V =
1986
- ggml_cpy (ctx0,
1987
- ggml_permute (ctx0,
1988
- ggml_reshape_3d (ctx0, Vcur, n_state_head, n_head, n_ctx),
1989
- 0 , 2 , 1 , 3 ),
1990
- ggml_view_3d (ctx0,
1991
- Vpad,
1992
- n_state_head, n_ctx, n_head, Vpad->nb [1 ], Vpad->nb [2 ], 0 ));
1993
-
1994
- ggml_build_forward_expand (gf, K);
1995
- ggml_build_forward_expand (gf, V);
1983
+ ggml_view_3d (ctx0, kv_pad.v ,
1984
+ n_state_head, n_ctx_pad, n_head,
1985
+ ggml_element_size (kv_pad.v )*n_state,
1986
+ ggml_element_size (kv_pad.v )*n_state_head,
1987
+ 0 );
1996
1988
1997
- cur = ggml_flash_attn_ext (ctx0, Q, Kpad, Vpad , nullptr , KQscale, 0 .0f );
1989
+ cur = ggml_flash_attn_ext (ctx0, Q, K, V , nullptr , KQscale, 0 .0f );
1998
1990
1999
1991
cur = ggml_reshape_2d (ctx0, cur, n_state, n_ctx);
2000
1992
#else
@@ -2519,7 +2511,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2519
2511
0 , 2 , 1 , 3 );
2520
2512
2521
2513
#ifdef WHISPER_USE_FLASH_ATTN
2522
- // Kcross is already scaled
2523
2514
struct ggml_tensor * Kcross =
2524
2515
ggml_view_3d (ctx0, wstate.kv_cross .k ,
2525
2516
n_state_head, n_audio_ctx_pad, n_head,
0 commit comments