Skip to content

Commit d637e62

Browse files
ggerganoviThalay
authored andcommitted
whisper : add batched decoding (ggml-org#1486)
* whisper : add whisper_batch * whisper : move kv_self to whisper_state * whisper : full batched decoding support * whisper : fix memory leak in whisper_batch * whisper : fix mem leak again + remove oboslete function * whisper : clear kv cache when using whisper_decode API * whisper : speed-up sampling * whisper : fix decoders initializer * bench : add batch size 5 bench * whisper : add comment about the KV cache size * whisper : add check for max number of decoders * whisper : avoid starting sampling threads with bs=1 * whisper : enable beam-search by default * cuda : sync llama.cpp fixes
1 parent 0203e11 commit d637e62

File tree

7 files changed

+826
-562
lines changed

7 files changed

+826
-562
lines changed

examples/bench/bench.cpp

+20-10
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
8181
}
8282
// heat encoder
8383
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
84-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
84+
fprintf(stderr, "error: failed to encode: %d\n", ret);
8585
return 4;
8686
}
8787

@@ -90,34 +90,44 @@ int whisper_bench_full(const whisper_params & params) {
9090

9191
// prompt heat
9292
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
93-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
93+
fprintf(stderr, "error: failed to decode: %d\n", ret);
9494
return 4;
9595
}
9696

9797
// text-generation heat
9898
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
99-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
99+
fprintf(stderr, "error: failed to decode: %d\n", ret);
100100
return 4;
101101
}
102102

103103
whisper_reset_timings(ctx);
104104

105105
// actual run
106106
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
107-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
107+
fprintf(stderr, "error: failed to encode: %d\n", ret);
108108
return 4;
109109
}
110110

111-
for (int i = 0; i < 16; i++) {
112-
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
113-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
111+
// text-generation
112+
for (int i = 0; i < 256; i++) {
113+
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
114+
fprintf(stderr, "error: failed to decode: %d\n", ret);
114115
return 4;
115116
}
116117
}
117118

118-
for (int i = 0; i < 256; i++) {
119-
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
120-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
119+
// batched decoding
120+
for (int i = 0; i < 64; i++) {
121+
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
122+
fprintf(stderr, "error: failed to decode: %d\n", ret);
123+
return 4;
124+
}
125+
}
126+
127+
// prompt processing
128+
for (int i = 0; i < 16; i++) {
129+
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
130+
fprintf(stderr, "error: failed to decode: %d\n", ret);
121131
return 4;
122132
}
123133
}

examples/main/main.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ struct whisper_params {
6262
int32_t progress_step = 5;
6363
int32_t max_context = -1;
6464
int32_t max_len = 0;
65-
int32_t best_of = 2;
66-
int32_t beam_size = -1;
65+
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
66+
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
6767

6868
float word_thold = 0.01f;
6969
float entropy_thold = 2.40f;
@@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
925925
if (params.detect_language) {
926926
params.language = "auto";
927927
}
928-
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
928+
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
929929
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
930-
params.n_threads, params.n_processors,
930+
params.n_threads, params.n_processors, params.beam_size, params.best_of,
931931
params.language.c_str(),
932932
params.translate ? "translate" : "transcribe",
933933
params.tinydiarize ? "tdrz = 1, " : "",

extra/bench-all.sh

+4-3
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
4444
printf "\n"
4545
fi
4646

47-
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
48-
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
47+
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
48+
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
4949

5050
for model in "${models[@]}"; do
5151
# actual run
@@ -56,6 +56,7 @@ for model in "${models[@]}"; do
5656
# parse the output:
5757
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
5858
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
59+
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
5960
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
6061
system_info=$(echo "$output" | grep "system_info")
6162
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
@@ -94,6 +95,6 @@ for model in "${models[@]}"; do
9495
commit=$(git rev-parse --short HEAD)
9596

9697
if [ $ret -eq 0 ]; then
97-
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
98+
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
9899
fi
99100
done

0 commit comments

Comments
 (0)