@@ -279,8 +279,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.yarn_beta_slow = std::stof(argv[i]);
-        } else if (arg == "--memory-f32") {
-            params.memory_f16 = false;
+        } else if (arg == "--samplers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.samplers_sequence = parse_samplers_input(argv[i]);
+        } else if (arg == "--sampling-seq") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.samplers_sequence = argv[i];
         } else if (arg == "--top-p") {
             if (++i >= argc) {
                 invalid_param = true;
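Note: both new flags write the same `sparams.samplers_sequence` field. `--samplers` takes semicolon-separated sampler names that `parse_samplers_input` (added further down in this commit) condenses into one-letter codes, while `--sampling-seq` supplies the condensed form directly. Given that helper's mapping, the following two invocations should be equivalent:

    --samplers "top_k;tfs;typical;top_p;min_p;temp"
    --sampling-seq kfypmt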
@@ -499,6 +509,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.infill = true;
         } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
             params.dump_kv_cache = true;
+        } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+            params.no_kv_offload = true;
+        } else if (arg == "-ctk" || arg == "--cache-type-k") {
+            params.cache_type_k = argv[++i];
+        } else if (arg == "-ctv" || arg == "--cache-type-v") {
+            params.cache_type_v = argv[++i];
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
@@ -679,6 +695,47 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 std::istreambuf_iterator<char>(),
                 std::back_inserter(sparams.grammar)
             );
+        } else if (arg == "--override-kv") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            char * sep = strchr(argv[i], '=');
+            if (sep == nullptr || sep - argv[i] >= 128) {
+                fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            struct llama_model_kv_override kvo;
+            std::strncpy(kvo.key, argv[i], sep - argv[i]);
+            kvo.key[sep - argv[i]] = 0;
+            sep++;
+            if (strncmp(sep, "int:", 4) == 0) {
+                sep += 4;
+                kvo.tag = LLAMA_KV_OVERRIDE_INT;
+                kvo.int_value = std::atol(sep);
+            } else if (strncmp(sep, "float:", 6) == 0) {
+                sep += 6;
+                kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
+                kvo.float_value = std::atof(sep);
+            } else if (strncmp(sep, "bool:", 5) == 0) {
+                sep += 5;
+                kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
+                if (std::strcmp(sep, "true") == 0) {
+                    kvo.bool_value = true;
+                } else if (std::strcmp(sep, "false") == 0) {
+                    kvo.bool_value = false;
+                } else {
+                    fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
+                    invalid_param = true;
+                    break;
+                }
+            } else {
+                fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            params.kv_overrides.push_back(kvo);
 #ifndef LOG_DISABLE_LOGS
     // Parse args for logging parameters
     } else if ( log_param_single_parse( argv[i] ) ) {
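Note: the accepted syntax is `KEY=TYPE:VALUE`, where `TYPE` is one of `int`, `float`, or `bool`, and the key must be shorter than 128 characters (the capacity of `llama_model_kv_override::key` implied by the `sep - argv[i] >= 128` check above). For example, the override shown in the usage text later in this commit:

    --override-kv tokenizer.ggml.add_bos_token=bool:false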
@@ -722,6 +779,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         }
     }

+    if (!params.kv_overrides.empty()) {
+        params.kv_overrides.emplace_back(llama_model_kv_override());
+        params.kv_overrides.back().key[0] = 0;
+    }
+
     return true;
 }

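Note: the empty-key element appended here acts as a sentinel. `llama_model_params.kv_overrides` is a bare pointer with no separate count (see `llama_model_params_from_gpt_params` below, which asserts on exactly this terminator), so the reader of the array needs an end marker. A minimal sketch of how a consumer might walk the array under that convention, with illustrative names:

    // hypothetical consumer loop: apply each override until the
    // empty-key sentinel is reached
    for (const llama_model_kv_override * ov = mparams.kv_overrides;
         ov != NULL && ov->key[0] != 0; ++ov) {
        // dispatch on ov->tag to ov->int_value / float_value / bool_value
    }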
@@ -762,6 +824,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  --samplers            samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
+    printf("  --sampling-seq        simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -799,8 +863,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf("  --no-penalize-nl      do not penalize newline token\n");
-    printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
-    printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
     printf("  --temp N              temperature (default: %.1f)\n", (double)sparams.temp);
     printf("  --logits-all          return logits for all tokens in the batch (default: disabled)\n");
     printf("  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
@@ -841,6 +903,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --verbose-prompt      print prompt before generation\n");
     printf("  -dkvc, --dump-kv-cache\n");
     printf("                        verbose print of the KV cache\n");
+    printf("  -nkvo, --no-kv-offload\n");
+    printf("                        disable KV offload\n");
+    printf("  -ctk TYPE, --cache-type-k TYPE\n");
+    printf("                        KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
+    printf("  -ctv TYPE, --cache-type-v TYPE\n");
+    printf("                        KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
     printf("  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -851,6 +919,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
+    printf("  --override-kv KEY=TYPE:VALUE\n");
+    printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
+    printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("\n");
 #ifndef LOG_DISABLE_LOGS
     log_print_usage();
@@ -887,6 +958,48 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
     GGML_UNREACHABLE();
 }

+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input) {
+    std::string output = "";
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map<std::string, char> samplers_symbols {
+        {"top_k",      'k'},
+        {"top-k",      'k'},
+        {"top_p",      'p'},
+        {"top-p",      'p'},
+        {"nucleus",    'p'},
+        {"typical_p",  'y'},
+        {"typical-p",  'y'},
+        {"typical",    'y'},
+        {"min_p",      'm'},
+        {"min-p",      'm'},
+        {"tfs_z",      'f'},
+        {"tfs-z",      'f'},
+        {"tfs",        'f'},
+        {"temp",       't'},
+        {"temperature",'t'}
+    };
+    // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
+    size_t separator = input.find(';');
+    while (separator != input.npos) {
+        std::string name = input.substr(0, separator);
+        input = input.substr(separator + 1);
+        separator = input.find(';');
+
+        if (samplers_symbols.find(name) != samplers_symbols.end()) {
+            output += samplers_symbols[name];
+        }
+    }
+    if (samplers_symbols.find(input) != samplers_symbols.end()) {
+        output += samplers_symbols[input];
+    }
+    return output;
+}
+
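As a worked example of the mapping above, `parse_samplers_input("top_k;tfs;typical;top_p;min_p;temp")` returns `"kfypmt"`. Names not present in `samplers_symbols` are silently dropped, so `parse_samplers_input("top_k;bogus;temp")` yields `"kt"` rather than reporting an error.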
 //
 // Model utils
 //
@@ -901,10 +1014,39 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
     mparams.tensor_split = params.tensor_split;
     mparams.use_mmap     = params.use_mmap;
     mparams.use_mlock    = params.use_mlock;
+    if (params.kv_overrides.empty()) {
+        mparams.kv_overrides = NULL;
+    } else {
+        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+        mparams.kv_overrides = params.kv_overrides.data();
+    }

     return mparams;
 }

+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    throw std::runtime_error("Invalid cache type: " + s);
+}
+
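Note: the set of cache types accepted by `-ctk`/`-ctv` is therefore `f16`, `q8_0`, `q4_0`, `q4_1`, `q5_0`, and `q5_1`; `f32` is no longer available (the `--memory-f32` option is removed elsewhere in this commit). Any other string surfaces as a `std::runtime_error` thrown from this helper when `llama_context_params_from_gpt_params` below runs, rather than as a parse-time error.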
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto cparams = llama_context_default_params();

@@ -914,7 +1056,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.mul_mat_q         = params.mul_mat_q;
     cparams.seed              = params.seed;
-    cparams.f16_kv            = params.memory_f16;
     cparams.logits_all        = params.logits_all;
     cparams.embedding         = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -925,6 +1066,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     cparams.yarn_beta_fast    = params.yarn_beta_fast;
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
+    cparams.offload_kqv       = !params.no_kv_offload;
+
+    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
+    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

     return cparams;
 }
@@ -1337,7 +1482,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
-    fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);