
Commit 299d255

llama : remove sampling from llama_context

ggml-ci

1 parent 43440c0 commit 299d255
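In short: `llama_context` no longer carries a sampling state, so `llama_get_sampling(ctx)` disappears from the examples and each caller creates and frees its own `llama_sampling`. Below is a minimal sketch of the lifecycle the updated examples converge on, assuming the `llama_sampling_init(model, grammar, root)` / `llama_sampling_free()` signatures visible in the changed call sites; the model path and the elided decode loop are placeholders.

    // Sketch only: mirrors examples/batched and examples/passkey after this commit.
    // The sampler is built from the model, not fetched from the context, and the caller frees it.
    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

        llama_context_params cparams = llama_context_default_params();
        llama_context * ctx = llama_new_context_with_model(model, cparams);

        // caller-owned sampler; the grammar string and root rule are optional (nullptr) here
        llama_sampling * smpl = llama_sampling_init(model, /*grammar*/ nullptr, /*root*/ nullptr);

        // ... tokenize, llama_decode(), pick tokens with smpl ...

        llama_sampling_free(smpl); // no longer owned or freed by the context
        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }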

File tree

21 files changed: +58 -124 lines changed


Diff for: common/common.cpp

+5 -5

@@ -264,6 +264,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         params.kv_overrides.back().key[0] = 0;
     }

+    if (params.sparams.seed == LLAMA_DEFAULT_SEED) {
+        params.sparams.seed = time(NULL);
+    }
+
     return true;
 }

@@ -294,8 +298,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa

     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
-        // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
-        params.seed = std::stoul(argv[i]);
         sparams.seed = std::stoul(argv[i]);
         return true;
     }

@@ -1404,7 +1406,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
     options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
     options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
-    options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed });
     options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads });
     options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
     options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });

@@ -1455,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
         " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });

     options.push_back({ "sampling" });
+    options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed });
     options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
         "(default: %s)", sampler_type_names.c_str() });
     options.push_back({ "*", " --sampling-seq SEQUENCE",

@@ -2199,7 +2201,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_ubatch = params.n_ubatch;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
-    cparams.seed = params.seed;
     cparams.logits_all = params.logits_all;
     cparams.embeddings = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;

@@ -3210,7 +3211,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l

     fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
-    fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");

Diff for: common/common.h

-2

@@ -59,8 +59,6 @@ enum dimre_method {
 };

 struct gpt_params {
-    uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
-
     int32_t n_threads = cpu_get_num_math();
     int32_t n_threads_draft = -1;
     int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)

Diff for: common/sampling.cpp

+2 -13

@@ -3,19 +3,10 @@
 #include <random>

 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
-    auto result = llama_sampling_init(params, llama_sampling_init(model, params.grammar.c_str(), "root"));
-
-    result->owned = true;
-
-    return result;
-}
-
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) {
     struct llama_sampling_context * result = new llama_sampling_context();

     result->params = params;
-    result->owned = false;
-    result->smpl = smpl;
+    result->smpl = llama_sampling_init(model, params.grammar.c_str(), "root");

     result->prev.resize(params.n_prev);

@@ -27,9 +18,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 }

 void llama_sampling_free(struct llama_sampling_context * ctx) {
-    if (ctx->owned) {
-        llama_sampling_free(ctx->smpl);
-    }
+    llama_sampling_free(ctx->smpl);

     delete ctx;
 }

Diff for: common/sampling.h

-3

@@ -71,8 +71,6 @@ struct llama_sampling_context {
     // mirostat sampler state
     float mirostat_mu;

-    bool owned;
-
     llama_sampling * smpl;

     // TODO: replace with ring-buffer

@@ -86,7 +84,6 @@ struct llama_sampling_context {

 // Create a new sampling context instance.
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl);

 void llama_sampling_free(struct llama_sampling_context * ctx);
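For the `common` sampling wrapper the picture is now simpler: the `llama_sampling *` overload and the `owned` flag are gone, so `llama_sampling_context` always builds its own `llama_sampling` from the model (grammar taken from `sparams.grammar`) and always frees it. A sketch of the call pattern the updated examples (main, infill, lookahead, lookup, llava-cli) use, with model and parameter setup omitted and includes approximate:

    // Sketch: common-layer sampling setup after this commit.
    #include "common.h"
    #include "sampling.h"

    #include <cstdio>

    static void run_sampling(const llama_model * model, const llama_sampling_params & sparams) {
        // single remaining overload: the wrapper creates its own llama_sampling from the model
        llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, model);
        if (!ctx_sampling) {
            fprintf(stderr, "failed to initialize sampling subsystem\n");
            return;
        }

        // ... sample/accept tokens through ctx_sampling as before ...

        // no 'owned' check anymore: the wrapped llama_sampling is always freed here
        llama_sampling_free(ctx_sampling);
    }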

Diff for: examples/batched/batched.cpp

+1 -1

@@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
     ctx_params.n_batch = std::max(n_predict, n_parallel);

     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-    llama_sampling * smpl = llama_get_sampling(ctx);
+    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);

     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

Diff for: examples/embedding/embedding.cpp

+1 -7

@@ -68,13 +68,7 @@ int main(int argc, char ** argv) {

     print_build_info();

-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

     llama_backend_init();
     llama_numa_init(params.numa);

Diff for: examples/eval-callback/eval-callback.cpp

-2

@@ -151,8 +151,6 @@ int main(int argc, char ** argv) {

     print_build_info();

-    std::mt19937 rng(params.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);

Diff for: examples/gritlm/gritlm.cpp

+6 -4

@@ -92,11 +92,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     return result;
 }

-static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
+static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) {
     std::string result;

     const llama_model * model = llama_get_model(ctx);
-    llama_sampling * smpl = llama_get_sampling(ctx);
     llama_token eos_token = llama_token_eos(model);

     llama_kv_cache_clear(ctx);

@@ -117,7 +116,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
         inputs.clear();

         llama_decode(ctx, bat);
-        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+        auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

         auto candidates = std::vector<llama_token_data>(llama_n_vocab(model));
         auto n_candidates = (int32_t)candidates.size();

@@ -173,6 +172,8 @@ int main(int argc, char * argv[]) {
     // create generation context
     llama_context * ctx = llama_new_context_with_model(model, cparams);

+    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
     {

@@ -209,9 +210,10 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx, prompt, true);
+        std::string response = generate(ctx, smpl, prompt, true);
     }

+    llama_sampling_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
     llama_backend_free();

Diff for: examples/infill/infill.cpp

+3 -10

@@ -156,16 +156,9 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }

-    LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
-    LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+    print_build_info();

-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    LOG_TEE("%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();

@@ -348,7 +341,7 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd;

-    ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));
+    ctx_sampling = llama_sampling_init(sparams, model);

     while (n_remain != 0 || params.interactive) {
         // predict

Diff for: examples/llava/llava-cli.cpp

+1 -1

@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_

     LOG_TEE("\n");

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, llama_get_sampling(ctx_llava->ctx_llama));
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->model);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

Diff for: examples/lookahead/lookahead.cpp

+1 -2

@@ -1,7 +1,6 @@
 #include "common.h"
 #include "llama.h"

-#include <cmath>
 #include <cstdio>
 #include <string>
 #include <vector>

@@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model);

     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);

Diff for: examples/lookup/lookup.cpp

+1 -3

@@ -3,13 +3,11 @@
 #include "common.h"
 #include "ngram-cache.h"

-#include <cmath>
 #include <cstdint>
 #include <cstdio>
 #include <fstream>
 #include <string>
 #include <vector>
-#include <unordered_map>

 int main(int argc, char ** argv){
     gpt_params params;

@@ -106,7 +104,7 @@ int main(int argc, char ** argv){

     bool has_eos = false;

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model);

     std::vector<llama_token> draft;

Diff for: examples/main/main.cpp

+3 -10

@@ -183,16 +183,9 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }

-    LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
-    LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+    print_build_info();

-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    LOG_TEE("%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();

@@ -532,7 +525,7 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }

-    ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));
+    ctx_sampling = llama_sampling_init(sparams, model);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

Diff for: examples/passkey/passkey.cpp

+2 -3

@@ -26,8 +26,6 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);
-
     int n_junk = params.n_junk;
     int n_keep = params.n_keep;
     int n_grp = params.grp_attn_n;

@@ -85,7 +83,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    llama_sampling * smpl = llama_get_sampling(ctx);
+    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);

     // tokenize the prompt
     std::vector<llama_token> tokens_list;

@@ -274,6 +272,7 @@ int main(int argc, char ** argv) {

     llama_batch_free(batch);

+    llama_sampling_free(smpl);
     llama_free(ctx);
     llama_free_model(model);

Diff for: examples/perplexity/perplexity.cpp

+1 -7

@@ -2007,13 +2007,7 @@ int main(int argc, char ** argv) {

     print_build_info();

-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

     llama_backend_init();
     llama_numa_init(params.numa);

Diff for: examples/quantize-stats/quantize-stats.cpp

+1 -2

@@ -319,8 +319,7 @@ int main(int argc, char ** argv) {
     }

     auto cparams = llama_context_default_params();
-    cparams.n_ctx = 256;
-    cparams.seed = 1;
+    cparams.n_ctx = 256;

     ctx = llama_new_context_with_model(model, cparams);
Diff for: examples/save-load-state/save-load-state.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include <vector>
55
#include <cstdio>
6-
#include <chrono>
76

87
int main(int argc, char ** argv) {
98
gpt_params params;
@@ -37,7 +36,7 @@ int main(int argc, char ** argv) {
3736
return 1;
3837
}
3938

40-
llama_sampling * smpl = llama_get_sampling(ctx);
39+
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
4140

4241
// tokenize prompt
4342
auto tokens = llama_tokenize(ctx, params.prompt, true);
@@ -97,7 +96,7 @@ int main(int argc, char ** argv) {
9796
// make new context
9897
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
9998

100-
llama_sampling * smpl2 = llama_get_sampling(ctx2);
99+
llama_sampling * smpl2 = llama_sampling_init(model, nullptr, nullptr);
101100

102101
printf("\nsecond run: %s", params.prompt.c_str());
103102

@@ -162,7 +161,7 @@ int main(int argc, char ** argv) {
162161
// make new context
163162
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
164163

165-
llama_sampling * smpl3 = llama_get_sampling(ctx3);
164+
llama_sampling * smpl3 = llama_sampling_init(model, nullptr, nullptr);
166165

167166
printf("\nsingle seq run: %s", params.prompt.c_str());
168167

@@ -245,6 +244,10 @@ int main(int argc, char ** argv) {
245244

246245
printf("\n");
247246

247+
llama_sampling_free(smpl);
248+
llama_sampling_free(smpl2);
249+
llama_sampling_free(smpl3);
250+
248251
llama_free(ctx3);
249252
llama_free_model(model);
250253

0 commit comments