Skip to content

Commit ec21fa7

Browse files
committed
Merge branch 'master' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	.gitignore
#	CMakeLists.txt
#	Makefile
#	Package.swift
#	README.md
#	ggml-cuda.cu
#	llama.cpp
#	llama.h
#	scripts/sync-ggml.sh
#	tests/CMakeLists.txt
2 parents 930cdfb + fe680e3 commit ec21fa7

34 files changed

+5822
-1370
lines changed

Diff for: .gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ poetry.toml
8181

8282
# Test binaries
8383
tests/test-grammar-parser
84+
/tests/test-llama-grammar
8485
tests/test-double-float
8586
tests/test-grad0
8687
tests/test-opt
@@ -92,6 +93,8 @@ tests/test-tokenizer-0-llama
9293
tests/test-tokenizer-0-falcon
9394
tests/test-tokenizer-1-llama
9495
tests/test-tokenizer-1-bpe
96+
/tests/test-rope
97+
/tests/test-backend-ops
9598

9699
/koboldcpp_default.so
97100
/koboldcpp_failsafe.so

Diff for: common/common.cpp

+150-6
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
279279
break;
280280
}
281281
params.yarn_beta_slow = std::stof(argv[i]);
282-
} else if (arg == "--memory-f32") {
283-
params.memory_f16 = false;
282+
} else if (arg == "--samplers") {
283+
if (++i >= argc) {
284+
invalid_param = true;
285+
break;
286+
}
287+
sparams.samplers_sequence = parse_samplers_input(argv[i]);
288+
} else if (arg == "--sampling-seq") {
289+
if (++i >= argc) {
290+
invalid_param = true;
291+
break;
292+
}
293+
sparams.samplers_sequence = argv[i];
284294
} else if (arg == "--top-p") {
285295
if (++i >= argc) {
286296
invalid_param = true;
@@ -499,6 +509,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
499509
params.infill = true;
500510
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
501511
params.dump_kv_cache = true;
512+
} else if (arg == "-nkvo" || arg == "--no-kv-offload") {
513+
params.no_kv_offload = true;
514+
} else if (arg == "-ctk" || arg == "--cache-type-k") {
515+
params.cache_type_k = argv[++i];
516+
} else if (arg == "-ctv" || arg == "--cache-type-v") {
517+
params.cache_type_v = argv[++i];
502518
} else if (arg == "--multiline-input") {
503519
params.multiline_input = true;
504520
} else if (arg == "--simple-io") {
@@ -679,6 +695,47 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
679695
std::istreambuf_iterator<char>(),
680696
std::back_inserter(sparams.grammar)
681697
);
698+
} else if (arg == "--override-kv") {
699+
if (++i >= argc) {
700+
invalid_param = true;
701+
break;
702+
}
703+
char * sep = strchr(argv[i], '=');
704+
if (sep == nullptr || sep - argv[i] >= 128) {
705+
fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
706+
invalid_param = true;
707+
break;
708+
}
709+
struct llama_model_kv_override kvo;
710+
std::strncpy(kvo.key, argv[i], sep - argv[i]);
711+
kvo.key[sep - argv[i]] = 0;
712+
sep++;
713+
if (strncmp(sep, "int:", 4) == 0) {
714+
sep += 4;
715+
kvo.tag = LLAMA_KV_OVERRIDE_INT;
716+
kvo.int_value = std::atol(sep);
717+
} else if (strncmp(sep, "float:", 6) == 0) {
718+
sep += 6;
719+
kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
720+
kvo.float_value = std::atof(sep);
721+
} else if (strncmp(sep, "bool:", 5) == 0) {
722+
sep += 5;
723+
kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
724+
if (std::strcmp(sep, "true") == 0) {
725+
kvo.bool_value = true;
726+
} else if (std::strcmp(sep, "false") == 0) {
727+
kvo.bool_value = false;
728+
} else {
729+
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
730+
invalid_param = true;
731+
break;
732+
}
733+
} else {
734+
fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
735+
invalid_param = true;
736+
break;
737+
}
738+
params.kv_overrides.push_back(kvo);
682739
#ifndef LOG_DISABLE_LOGS
683740
// Parse args for logging parameters
684741
} else if ( log_param_single_parse( argv[i] ) ) {
@@ -722,6 +779,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
722779
}
723780
}
724781

782+
if (!params.kv_overrides.empty()) {
783+
params.kv_overrides.emplace_back(llama_model_kv_override());
784+
params.kv_overrides.back().key[0] = 0;
785+
}
786+
725787
return true;
726788
}
727789

@@ -762,6 +824,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
762824
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
763825
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
764826
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
827+
printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
828+
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
765829
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
766830
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
767831
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -799,8 +863,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
799863
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
800864
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
801865
printf(" --no-penalize-nl do not penalize newline token\n");
802-
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
803-
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
804866
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
805867
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
806868
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
@@ -841,6 +903,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
841903
printf(" --verbose-prompt print prompt before generation\n");
842904
printf(" -dkvc, --dump-kv-cache\n");
843905
printf(" verbose print of the KV cache\n");
906+
printf(" -nkvo, --no-kv-offload\n");
907+
printf(" disable KV offload\n");
908+
printf(" -ctk TYPE, --cache-type-k TYPE\n");
909+
printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
910+
printf(" -ctv TYPE, --cache-type-v TYPE\n");
911+
printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
844912
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
845913
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
846914
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -851,6 +919,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
851919
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
852920
printf(" -ld LOGDIR, --logdir LOGDIR\n");
853921
printf(" path under which to save YAML logs (no logging if unset)\n");
922+
printf(" --override-kv KEY=TYPE:VALUE\n");
923+
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
924+
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
854925
printf("\n");
855926
#ifndef LOG_DISABLE_LOGS
856927
log_print_usage();
@@ -887,6 +958,48 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
887958
GGML_UNREACHABLE();
888959
}
889960

961+
//
962+
// String parsing
963+
//
964+
965+
std::string parse_samplers_input(std::string input) {
966+
std::string output = "";
967+
// since samplers names are written multiple ways
968+
// make it ready for both system names and input names
969+
std::unordered_map<std::string, char> samplers_symbols {
970+
{"top_k", 'k'},
971+
{"top-k", 'k'},
972+
{"top_p", 'p'},
973+
{"top-p", 'p'},
974+
{"nucleus", 'p'},
975+
{"typical_p", 'y'},
976+
{"typical-p", 'y'},
977+
{"typical", 'y'},
978+
{"min_p", 'm'},
979+
{"min-p", 'm'},
980+
{"tfs_z", 'f'},
981+
{"tfs-z", 'f'},
982+
{"tfs", 'f'},
983+
{"temp", 't'},
984+
{"temperature",'t'}
985+
};
986+
// expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
987+
size_t separator = input.find(';');
988+
while (separator != input.npos) {
989+
std::string name = input.substr(0,separator);
990+
input = input.substr(separator+1);
991+
separator = input.find(';');
992+
993+
if (samplers_symbols.find(name) != samplers_symbols.end()) {
994+
output += samplers_symbols[name];
995+
}
996+
}
997+
if (samplers_symbols.find(input) != samplers_symbols.end()) {
998+
output += samplers_symbols[input];
999+
}
1000+
return output;
1001+
}
1002+
8901003
//
8911004
// Model utils
8921005
//
@@ -901,10 +1014,39 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
9011014
mparams.tensor_split = params.tensor_split;
9021015
mparams.use_mmap = params.use_mmap;
9031016
mparams.use_mlock = params.use_mlock;
1017+
if (params.kv_overrides.empty()) {
1018+
mparams.kv_overrides = NULL;
1019+
} else {
1020+
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
1021+
mparams.kv_overrides = params.kv_overrides.data();
1022+
}
9041023

9051024
return mparams;
9061025
}
9071026

1027+
static ggml_type kv_cache_type_from_str(const std::string & s) {
1028+
if (s == "f16") {
1029+
return GGML_TYPE_F16;
1030+
}
1031+
if (s == "q8_0") {
1032+
return GGML_TYPE_Q8_0;
1033+
}
1034+
if (s == "q4_0") {
1035+
return GGML_TYPE_Q4_0;
1036+
}
1037+
if (s == "q4_1") {
1038+
return GGML_TYPE_Q4_1;
1039+
}
1040+
if (s == "q5_0") {
1041+
return GGML_TYPE_Q5_0;
1042+
}
1043+
if (s == "q5_1") {
1044+
return GGML_TYPE_Q5_1;
1045+
}
1046+
1047+
throw std::runtime_error("Invalid cache type: " + s);
1048+
}
1049+
9081050
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
9091051
auto cparams = llama_context_default_params();
9101052

@@ -914,7 +1056,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
9141056
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
9151057
cparams.mul_mat_q = params.mul_mat_q;
9161058
cparams.seed = params.seed;
917-
cparams.f16_kv = params.memory_f16;
9181059
cparams.logits_all = params.logits_all;
9191060
cparams.embedding = params.embedding;
9201061
cparams.rope_scaling_type = params.rope_scaling_type;
@@ -925,6 +1066,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
9251066
cparams.yarn_beta_fast = params.yarn_beta_fast;
9261067
cparams.yarn_beta_slow = params.yarn_beta_slow;
9271068
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
1069+
cparams.offload_kqv = !params.no_kv_offload;
1070+
1071+
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
1072+
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
9281073

9291074
return cparams;
9301075
}
@@ -1337,7 +1482,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
13371482
}
13381483
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
13391484
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
1340-
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
13411485
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
13421486
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
13431487
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);

Diff for: common/common.h

+13-2
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ struct gpt_params {
9494
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
9595
std::string logdir = ""; // directory in which to save YAML log files
9696

97+
std::vector<llama_model_kv_override> kv_overrides;
98+
9799
// TODO: avoid tuple, use struct
98100
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
99101
std::string lora_base = ""; // base model path for the lora adapter
@@ -106,7 +108,6 @@ struct gpt_params {
106108
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
107109

108110
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
109-
bool memory_f16 = true; // use f16 instead of f32 for memory kv
110111
bool random_prompt = false; // do not randomize prompt if none provided
111112
bool use_color = false; // use color to distinguish generations and inputs
112113
bool interactive = false; // interactive mode
@@ -131,10 +132,14 @@ struct gpt_params {
131132
bool verbose_prompt = false; // print prompt tokens before generation
132133
bool infill = false; // use infill mode
133134
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
135+
bool no_kv_offload = false; // disable KV offloading
136+
137+
std::string cache_type_k = "f16"; // KV cache data type for the K
138+
std::string cache_type_v = "f16"; // KV cache data type for the V
134139

135140
// multimodal models (see examples/llava)
136141
std::string mmproj = ""; // path to multimodal projector
137-
std::string image = ""; // path to an image file
142+
std::string image = ""; // path to an image file
138143
};
139144

140145
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
@@ -149,6 +154,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
149154

150155
void process_escapes(std::string& input);
151156

157+
//
158+
// String parsing
159+
//
160+
161+
std::string parse_samplers_input(std::string input);
162+
152163
//
153164
// Model utils
154165
//

Diff for: common/grammar-parser.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ namespace grammar_parser {
190190
pos = parse_space(pos + 1, is_nested);
191191
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
192192
if (last_sym_start == out_elements.size()) {
193-
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
193+
throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
194194
}
195195

196196
// apply transformation to previous symbol (last_sym_start to end) according to

0 commit comments

Comments (0)