Skip to content

Commit 62984db

Browse files
committed
cont : move samplers to llama lib
ggml-ci
1 parent 861ad6f commit 62984db

File tree

6 files changed

+83
-81
lines changed

6 files changed

+83
-81
lines changed

Diff for: common/common.cpp

+7-7
Original file line number | Diff line number | Diff line change
@@ -615,7 +615,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
615615
}
616616
if (arg == "--typical") {
617617
CHECK_ARG
618-
sparams.typical_p = std::stof(argv[i]);
618+
sparams.typ_p = std::stof(argv[i]);
619619
return true;
620620
}
621621
if (arg == "--repeat-last-n") {
@@ -1532,12 +1532,12 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
15321532
"simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
15331533
options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
15341534
options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
1535-
options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp });
1535+
options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp });
15361536
options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
1537-
options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
1538-
options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
1539-
options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
1540-
options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
1537+
options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
1538+
options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
1539+
options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
1540+
options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p });
15411541
options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
15421542
options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
15431543
options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
@@ -3316,7 +3316,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
33163316
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
33173317
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
33183318
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
3319-
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
3319+
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
33203320
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
33213321
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
33223322
}

Diff for: common/sampling.cpp

+7-68
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m
1818
lparams.top_p = params.top_p;
1919
lparams.min_p = params.min_p;
2020
lparams.tfs_z = params.tfs_z;
21-
lparams.typical_p = params.typical_p;
21+
lparams.typ_p = params.typ_p;
2222
lparams.temp = params.temp;
2323
lparams.dynatemp_range = params.dynatemp_range;
2424
lparams.dynatemp_exponent = params.dynatemp_exponent;
@@ -94,7 +94,7 @@ std::string llama_sampling_print(const gpt_sampling_params & params) {
9494
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
9595
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
9696
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
97-
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
97+
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typ_p, params.temp,
9898
params.mirostat, params.mirostat_eta, params.mirostat_tau);
9999

100100
return std::string(result);
@@ -132,7 +132,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
132132
switch (sampler_type) {
133133
case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k";
134134
case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z";
135-
case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typical_p";
135+
case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
136136
case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p";
137137
case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p";
138138
case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature";
@@ -144,7 +144,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
144144
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
145145
{ "top_k", LLAMA_SAMPLER_TYPE_TOP_K },
146146
{ "top_p", LLAMA_SAMPLER_TYPE_TOP_P },
147-
{ "typical_p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
147+
{ "typ_p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
148148
{ "min_p", LLAMA_SAMPLER_TYPE_MIN_P },
149149
{ "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z },
150150
{ "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE },
@@ -158,6 +158,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
158158
{ "nucleus", LLAMA_SAMPLER_TYPE_TOP_P },
159159
{ "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
160160
{ "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P },
161+
{ "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
162+
{ "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P },
161163
{ "min-p", LLAMA_SAMPLER_TYPE_MIN_P },
162164
{ "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z },
163165
{ "tfs", LLAMA_SAMPLER_TYPE_TFS_Z },
@@ -205,29 +207,6 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
205207
return sampler_types;
206208
}
207209

208-
// no reasons to expose this function in header
209-
static void sampler_queue(
210-
struct llama_sampling_context * ctx_sampling,
211-
struct llama_token_data_array * cur_p) {
212-
llama_sampling * smpl = ctx_sampling->smpl;
213-
214-
const gpt_sampling_params & params = ctx_sampling->params;
215-
216-
const std::vector<llama_sampler_type> & samplers = params.samplers;
217-
218-
for (const auto & sampler : samplers) {
219-
switch (sampler) {
220-
case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k (smpl, cur_p); break;
221-
case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free(smpl, cur_p); break;
222-
case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical (smpl, cur_p); break;
223-
case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p (smpl, cur_p); break;
224-
case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p (smpl, cur_p); break;
225-
case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp (smpl, cur_p); break;
226-
default : break;
227-
}
228-
}
229-
}
230-
231210
void llama_sampling_prepare(
232211
struct llama_sampling_context * ctx_sampling,
233212
struct llama_context * ctx_main,
@@ -238,47 +217,7 @@ void llama_sampling_prepare(
238217
static llama_token llama_sampling_sample(
239218
struct llama_sampling_context * ctx_sampling,
240219
struct llama_token_data_array * cur_p) {
241-
llama_sampling * smpl = ctx_sampling->smpl;
242-
243-
const gpt_sampling_params & params = ctx_sampling->params;
244-
245-
const float temp = params.temp;
246-
const int mirostat = params.mirostat;
247-
248-
llama_token id = 0;
249-
250-
if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
251-
// greedy sampling, with probs
252-
llama_sampling_softmax(smpl, cur_p);
253-
id = cur_p->data[0].id;
254-
} else if (temp == 0.0f) {
255-
// greedy sampling, no probs
256-
id = llama_sampling_sample_greedy(smpl, cur_p);
257-
} else {
258-
if (mirostat != 0) {
259-
llama_sampling_temp(smpl, cur_p);
260-
id = llama_sampling_sample_mirostat(smpl, cur_p);
261-
} else {
262-
sampler_queue(ctx_sampling, cur_p);
263-
264-
id = llama_sampling_sample_dist(smpl, cur_p);
265-
266-
//{
267-
// const int n_top = 10;
268-
// LOG("top %d candidates:\n", n_top);
269-
270-
// for (int i = 0; i < n_top; i++) {
271-
// const llama_token id = cur_p.data[i].id;
272-
// (void)id; // To avoid a warning that id is unused when logging is disabled.
273-
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
274-
// }
275-
//}
276-
277-
//LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl, id).c_str());
278-
}
279-
}
280-
281-
return id;
220+
return llama_sampling_sample(ctx_sampling->smpl, cur_p);
282221
}
283222

284223
llama_token llama_sampling_sample(

Diff for: common/sampling.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ typedef struct gpt_sampling_params {
1616
float top_p = 0.95f; // 1.0 = disabled
1717
float min_p = 0.05f; // 0.0 = disabled
1818
float tfs_z = 1.00f; // 1.0 = disabled
19-
float typical_p = 1.00f; // 1.0 = disabled
19+
float typ_p = 1.00f; // typical_p, 1.0 = disabled
2020
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
2121
float dynatemp_range = 0.00f; // 0.0 = disabled
2222
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler

Diff for: examples/server/server.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -914,7 +914,7 @@ struct server_context {
914914
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
915915
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
916916
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
917-
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
917+
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
918918
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
919919
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
920920
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@@ -1283,7 +1283,7 @@ struct server_context {
12831283
{"top_p", slot.sparams.top_p},
12841284
{"min_p", slot.sparams.min_p},
12851285
{"tfs_z", slot.sparams.tfs_z},
1286-
{"typical_p", slot.sparams.typical_p},
1286+
{"typical_p", slot.sparams.typ_p},
12871287
{"repeat_last_n", slot.sparams.penalty_last_n},
12881288
{"repeat_penalty", slot.sparams.penalty_repeat},
12891289
{"presence_penalty", slot.sparams.penalty_present},

Diff for: include/llama.h

+6-1
Original file line number | Diff line number | Diff line change
@@ -388,7 +388,7 @@ extern "C" {
388388
float top_p; // 1.0 = disabled
389389
float min_p; // 0.0 = disabled
390390
float tfs_z; // 1.0 = disabled
391-
float typical_p; // 1.0 = disabled
391+
float typ_p; // typical_p, 1.0 = disabled
392392
float temp; // <= 0.0 to sample greedily, 0.0 to not output probabilities
393393
float dynatemp_range; // 0.0 = disabled
394394
float dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
@@ -1106,6 +1106,11 @@ extern "C" {
11061106
struct llama_sampling * smpl,
11071107
llama_token_data_array * candidates);
11081108

1109+
/// @details Sample a token using the configured samplers.
1110+
LLAMA_API llama_token llama_sampling_sample(
1111+
struct llama_sampling * smpl,
1112+
llama_token_data_array * candidates);
1113+
11091114
/// @details Accepts the sampled token into the sampling context
11101115
LLAMA_API void llama_sampling_accept(
11111116
struct llama_sampling * smpl,

Diff for: src/llama.cpp

+60-2
Original file line number | Diff line number | Diff line change
@@ -17418,7 +17418,7 @@ struct llama_sampling_params llama_sampling_default_params() {
1741817418
/*.top_p =*/ 0.95f,
1741917419
/*.min_p =*/ 0.05f,
1742017420
/*.tfs_z =*/ 1.00f,
17421-
/*.typical_p =*/ 1.00f,
17421+
/*.typ_p =*/ 1.00f,
1742217422
/*.temp =*/ 0.80f,
1742317423
/*.dynatemp_range =*/ 0.00f,
1742417424
/*.dynatemp_exponent =*/ 1.00f,
@@ -20169,7 +20169,7 @@ void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_arr
2016920169
void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2017020170
time_meas tm(smpl->t_sample_us);
2017120171

20172-
llama_sampling_typical_impl(candidates, smpl->params.typical_p, smpl->params.min_keep);
20172+
llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
2017320173
}
2017420174

2017520175
void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
@@ -20271,6 +20271,64 @@ llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token
2027120271
return res;
2027220272
}
2027320273

20274+
llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
20275+
time_meas tm(smpl->t_sample_us);
20276+
20277+
const auto & params = smpl->params;
20278+
20279+
const float temp = params.temp;
20280+
const int mirostat = params.mirostat;
20281+
20282+
auto & cur_p = candidates;
20283+
20284+
llama_token res = 0;
20285+
20286+
if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
20287+
// greedy sampling, with probs
20288+
llama_sampling_softmax_impl(cur_p);
20289+
res = cur_p->data[0].id;
20290+
} else if (temp == 0.0f) {
20291+
// greedy sampling, no probs
20292+
res = llama_sampling_sample_greedy(smpl, cur_p);
20293+
} else {
20294+
if (mirostat != 0) {
20295+
llama_sampling_temp(smpl, cur_p);
20296+
res = llama_sampling_sample_mirostat(smpl, cur_p);
20297+
} else {
20298+
for (const auto & sampler : smpl->samplers) {
20299+
switch (sampler) {
20300+
case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
20301+
case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
20302+
case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
20303+
case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
20304+
case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
20305+
case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break;
20306+
default : break;
20307+
}
20308+
}
20309+
20310+
res = llama_sampling_sample_dist(smpl, cur_p);
20311+
20312+
//{
20313+
// const int n_top = 10;
20314+
// LOG("top %d candidates:\n", n_top);
20315+
20316+
// for (int i = 0; i < n_top; i++) {
20317+
// const llama_token id = cur_p.data[i].id;
20318+
// (void)id; // To avoid a warning that id is unused when logging is disabled.
20319+
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
20320+
// }
20321+
//}
20322+
20323+
//LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
20324+
}
20325+
}
20326+
20327+
smpl->n_sample++;
20328+
20329+
return res;
20330+
}
20331+
2027420332
void llama_sampling_accept(
2027520333
struct llama_sampling * smpl,
2027620334
llama_token token,

0 commit comments

Comments (0)