Skip to content

Commit e9cacb3

Browse files
sandrohaneaSandro Haneaggerganov
authored
whisper : add whisper_state + default state on the whisper_context (ggml-org#523)
* Added whisper state + default state on the whisper_context * Fixed some examples and bindings * Fixed whisper_n_len (which was used in some binding) and added whisper_n_len_from_state * Fixed comments * whisper : reuse kv_cache_free() and fix compiler warnings * whisper : clean-up the API comments --------- Co-authored-by: Sandro Hanea <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 7eebfb1 commit e9cacb3

File tree

6 files changed

+701
-453
lines changed

6 files changed

+701
-453
lines changed

bindings/go/whisper.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data);
2020
// Text segment callback
2121
// Called on every newly generated text segment
2222
// Use the whisper_full_...() functions to obtain the text segments
23-
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
23+
static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
2424
if(user_data != NULL && ctx != NULL) {
2525
callNewSegment(user_data, n_new);
2626
}
@@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void*
2929
// Encoder begin callback
3030
// If not NULL, called before the encoder starts
3131
// If it returns false, the computation is aborted
32-
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
32+
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
3333
if(user_data != NULL && ctx != NULL) {
3434
return callEncoderBegin(user_data);
3535
}

bindings/ruby/ext/ruby_whisper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
199199
{
200200
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
201201

202-
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
202+
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
203203
bool is_aborted = *(bool*)user_data;
204204
return !is_aborted;
205205
};

examples/addon.node/addon.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) {
7272
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
7373
}
7474

75-
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
75+
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
7676
const auto & params = *((whisper_print_user_data *) user_data)->params;
7777
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
7878

@@ -260,7 +260,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
260260
{
261261
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
262262

263-
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
263+
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
264264
bool is_aborted = *(bool*)user_data;
265265
return !is_aborted;
266266
};

examples/main/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ struct whisper_print_user_data {
193193
const std::vector<std::vector<float>> * pcmf32s;
194194
};
195195

196-
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
196+
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
197197
const auto & params = *((whisper_print_user_data *) user_data)->params;
198198
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
199199

@@ -608,7 +608,7 @@ int main(int argc, char ** argv) {
608608
{
609609
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
610610

611-
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
611+
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
612612
bool is_aborted = *(bool*)user_data;
613613
return !is_aborted;
614614
};

0 commit comments

Comments
 (0)