1
1
#include " common.h"
2
2
3
3
#include " whisper.h"
4
+ #include " grammar-parser.h"
4
5
5
6
#include < cmath>
6
7
#include < fstream>
@@ -38,9 +39,10 @@ struct whisper_params {
38
39
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
39
40
int32_t audio_ctx = 0 ;
40
41
41
- float word_thold = 0 .01f ;
42
- float entropy_thold = 2 .40f ;
43
- float logprob_thold = -1 .00f ;
42
+ float word_thold = 0 .01f ;
43
+ float entropy_thold = 2 .40f ;
44
+ float logprob_thold = -1 .00f ;
45
+ float grammar_penalty = 100 .0f ;
44
46
45
47
bool speed_up = false ;
46
48
bool debug_mode = false ;
@@ -70,6 +72,8 @@ struct whisper_params {
70
72
std::string prompt;
71
73
std::string font_path = " /System/Library/Fonts/Supplemental/Courier New Bold.ttf" ;
72
74
std::string model = " models/ggml-base.en.bin" ;
75
+ std::string grammar;
76
+ std::string grammar_rule;
73
77
74
78
// [TDRZ] speaker turn string
75
79
std::string tdrz_speaker_turn = " [SPEAKER_TURN]" ; // TODO: set from command line
@@ -80,6 +84,8 @@ struct whisper_params {
80
84
81
85
std::vector<std::string> fname_inp = {};
82
86
std::vector<std::string> fname_out = {};
87
+
88
+ grammar_parser::parse_state grammar_parsed;
83
89
};
84
90
85
91
void whisper_print_usage (int argc, char ** argv, const whisper_params & params);
@@ -154,6 +160,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
154
160
else if (arg == " -dtw" || arg == " --dtw" ) { params.dtw = argv[++i]; }
155
161
else if (arg == " -ls" || arg == " --log-score" ) { params.log_score = true ; }
156
162
else if (arg == " -ng" || arg == " --no-gpu" ) { params.use_gpu = false ; }
163
+ else if ( arg == " --grammar" ) { params.grammar = argv[++i]; }
164
+ else if ( arg == " --grammar-rule" ) { params.grammar_rule = argv[++i]; }
165
+ else if ( arg == " --grammar-penalty" ) { params.grammar_penalty = std::stof (argv[++i]); }
157
166
else {
158
167
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
159
168
whisper_print_usage (argc, argv, params);
@@ -214,6 +223,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
214
223
fprintf (stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n " , params.dtw .c_str ());
215
224
fprintf (stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n " , params.log_score ?" true" :" false" );
216
225
fprintf (stderr, " -ng, --no-gpu [%-7s] disable GPU\n " , params.use_gpu ? " false" : " true" );
226
+ fprintf (stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n " , params.grammar .c_str ());
227
+ fprintf (stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n " , params.grammar_rule .c_str ());
228
+ fprintf (stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n " , params.grammar_penalty );
217
229
fprintf (stderr, " \n " );
218
230
}
219
231
@@ -926,6 +938,29 @@ int main(int argc, char ** argv) {
926
938
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
927
939
whisper_ctx_init_openvino_encoder (ctx, nullptr , params.openvino_encode_device .c_str (), nullptr );
928
940
941
+ if (!params.grammar .empty ()) {
942
+ auto & grammar = params.grammar_parsed ;
943
+ if (is_file_exist (params.grammar .c_str ())) {
944
+ // read grammar from file
945
+ std::ifstream ifs (params.grammar .c_str ());
946
+ const std::string txt = std::string ((std::istreambuf_iterator<char >(ifs)), std::istreambuf_iterator<char >());
947
+ grammar = grammar_parser::parse (txt.c_str ());
948
+ } else {
949
+ // read grammar from string
950
+ grammar = grammar_parser::parse (params.grammar .c_str ());
951
+ }
952
+
953
+ // will be empty (default) if there are parse errors
954
+ if (grammar.rules .empty ()) {
955
+ fprintf (stderr, " error: failed to parse grammar \" %s\"\n " , params.grammar .c_str ());
956
+ return 4 ;
957
+ } else {
958
+ fprintf (stderr, " %s: grammar:\n " , __func__);
959
+ grammar_parser::print_grammar (stderr, grammar);
960
+ fprintf (stderr, " \n " );
961
+ }
962
+ }
963
+
929
964
for (int f = 0 ; f < (int ) params.fname_inp .size (); ++f) {
930
965
const auto fname_inp = params.fname_inp [f];
931
966
const auto fname_out = f < (int ) params.fname_out .size () && !params.fname_out [f].empty () ? params.fname_out [f] : params.fname_inp [f];
@@ -972,7 +1007,8 @@ int main(int argc, char ** argv) {
972
1007
{
973
1008
whisper_full_params wparams = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
974
1009
975
- wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
1010
+ const bool use_grammar = (!params.grammar_parsed .rules .empty () && !params.grammar_rule .empty ());
1011
+ wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
976
1012
977
1013
wparams.print_realtime = false ;
978
1014
wparams.print_progress = params.print_progress ;
@@ -1010,6 +1046,20 @@ int main(int argc, char ** argv) {
1010
1046
1011
1047
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
1012
1048
1049
+ const auto & grammar_parsed = params.grammar_parsed ;
1050
+ auto grammar_rules = grammar_parsed.c_rules ();
1051
+
1052
+ if (use_grammar) {
1053
+ if (grammar_parsed.symbol_ids .find (params.grammar_rule ) == grammar_parsed.symbol_ids .end ()) {
1054
+ fprintf (stderr, " %s: warning: grammar rule '%s' not found - skipping grammar sampling\n " , __func__, params.grammar_rule .c_str ());
1055
+ } else {
1056
+ wparams.grammar_rules = grammar_rules.data ();
1057
+ wparams.n_grammar_rules = grammar_rules.size ();
1058
+ wparams.i_start_rule = grammar_parsed.symbol_ids .at (params.grammar_rule );
1059
+ wparams.grammar_penalty = params.grammar_penalty ;
1060
+ }
1061
+ }
1062
+
1013
1063
// this callback is called on each new segment
1014
1064
if (!wparams.print_realtime ) {
1015
1065
wparams.new_segment_callback = whisper_print_segment_callback;
0 commit comments