Skip to content

Commit c21dc41

Browse files
committed
class LlamaSampler: add add_xtc(), add_top_n_sigma() and add_dry()
Remove Tail-Free sampling; add TopN-Sigma/XTC/DRY sampler code into the sampler
1 parent 37ae873 commit c21dc41

File tree

6 files changed

+506
-73
lines changed

6 files changed

+506
-73
lines changed

examples/low_level_api/common.py

+60-9
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ class GptParams:
2121
ignore_eos: bool = False
2222
logit_bias: dict[int, float] = field(default_factory=dict)
2323
top_k: int = 40
24+
top_n_sigma: float = -1.00
2425
top_p: float = 0.95
25-
tfs_z: float = 1.00
26+
2627
typical_p: float = 1.00
2728
temp: float = 0.80
2829
repeat_penalty: float = 1.10
@@ -32,7 +33,13 @@ class GptParams:
3233
mirostat: int = 0
3334
mirostat_tau: float = 5.0
3435
mirostat_eta: float = 0.1
35-
36+
xtc_threshold: float = 0.1
37+
xtc_probability: float = 0.0
38+
dry_multiplier: float = 0.0
39+
dry_base: float = 1.75
40+
dry_allowed_length: int = 2
41+
dry_penalty_last_n:int = 0
42+
dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"]
3643
model: str = "./models/llama-7B/ggml-model.bin"
3744
prompt: str = ""
3845
path_session: str = ""
@@ -147,14 +154,10 @@ def gpt_params_parse(argv=None):
147154
"--top_k", type=int, default=40, help="top-k sampling", dest="top_k"
148155
)
149156
parser.add_argument(
150-
"--top_p", type=float, default=0.95, help="top-p samplin", dest="top_p"
157+
"--top_n_sigma", type=int, default=40, help="top-n-sigma sampling", dest="top_n_sigma"
151158
)
152159
parser.add_argument(
153-
"--tfs",
154-
type=float,
155-
default=1.0,
156-
help="tail free sampling, parameter z (1.0 = disabled)",
157-
dest="tfs_z",
160+
"--top_p", type=float, default=0.95, help="top-p samplin", dest="top_p"
158161
)
159162
parser.add_argument(
160163
"--temp", type=float, default=0.80, help="temperature", dest="temp"
@@ -178,7 +181,7 @@ def gpt_params_parse(argv=None):
178181
type=float,
179182
default=0.0,
180183
help="repeat alpha frequency penalty (0.0 = disabled)",
181-
dest="tfs_z",
184+
dest="frequency_penalty",
182185
)
183186
parser.add_argument(
184187
"--presence_penalty",
@@ -209,6 +212,54 @@ def gpt_params_parse(argv=None):
209212
dest="mirostat_eta",
210213
)
211214

215+
parser.add_argument(
216+
"--xtc_threshold",
217+
type=float,
218+
default=0.1,
219+
help="Sets a minimum probability threshold for tokens to be removed (default: 0.1)",
220+
dest="xtc_threshold",
221+
)
222+
223+
parser.add_argument(
224+
"--xtc_probability",
225+
type=float,
226+
default=0.0,
227+
help="Sets the chance for token removal (checked once on sampler start) (default: 0.0)",
228+
dest="xtc_probability",
229+
)
230+
231+
parser.add_argument(
232+
"--dry_multiplier",
233+
type=float,
234+
default=0.0,
235+
help="Set the DRY repetition penalty multiplier. Default is 0.0, which disables DRY.",
236+
dest="dry_multiplier",
237+
)
238+
239+
parser.add_argument(
240+
"--dry_base",
241+
type=float,
242+
default=1.75,
243+
help="Set the DRY repetition penalty base value. Default is 1.75",
244+
dest="dry_base",
245+
)
246+
247+
parser.add_argument(
248+
"--dry_allowed_length",
249+
type=int,
250+
default=2,
251+
help="Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2",
252+
dest="dry_allowed_length",
253+
)
254+
255+
parser.add_argument(
256+
"--dry_penalty_last_n",
257+
type=int,
258+
default=0,
259+
help="How many tokens to scan for repetitions. Default is 0, where 0 is disabled and -1 is context size",
260+
dest="dry_penalty_last_n",
261+
)
262+
212263
parser.add_argument(
213264
"-m",
214265
"--model",

examples/low_level_api/low_level_api_chat_cpp.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,23 @@ def __init__(self, params: GptParams) -> None:
275275
presence_penalty = {self.params.presence_penalty},\
276276
frequency_penalty = {self.params.frequency_penalty},\
277277
top_k = {self.params.top_k},\
278-
tfs_z = {self.params.tfs_z},\
278+
top_n_sigma = {self.params.top_n_sigma},\
279279
top_p = {self.params.top_p},\
280280
typical_p = {self.params.typical_p},\
281281
temp = {self.params.temp},\
282282
mirostat = {self.params.mirostat},\
283283
mirostat_lr = {self.params.mirostat_eta},\
284284
mirostat_ent = {self.params.mirostat_tau},\
285285
286+
xtc_threshold = {self.params.xtc_threshold},\
287+
xtc_probability = {self.params.xtc_probability},\
288+
289+
dry_multiplier = {self.params.dry_multiplier},\
290+
dry_base = {self.params.dry_base},\
291+
dry_allowed_length = {self.params.dry_allowed_length},\
292+
dry_penalty_last_n = {self.params.dry_penalty_last_n},\
293+
dry_seq_breakers = {self.params.dry_seq_breakers},\
294+
286295
generate: n_ctx = {self.n_ctx},\
287296
n_batch = {self.params.n_batch},\
288297
n_predict = {self.params.n_predict},\
@@ -454,7 +463,7 @@ def generate(self):
454463
_arr = (llama_cpp.llama_token * last_n_repeat)(
455464
*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat :]
456465
)
457-
llama_cpp.llama_sample_repetition_penalties(
466+
llama_cpp.llama_sampler_init_penalties(
458467
ctx=self.ctx,
459468
candidates=candidates_p,
460469
last_tokens_data=_arr,
@@ -474,15 +483,15 @@ def generate(self):
474483

475484
if self.params.temp <= 0:
476485
# Greedy sampling
477-
id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
486+
id = llama_cpp.llama_sampler_init_greedy(self.ctx, candidates_p)
478487
else:
479488
if self.params.mirostat == 1:
480489
mirostat_mu = 2.0 * self.params.mirostat_tau
481490
mirostat_m = 100
482-
llama_cpp.llama_sample_temperature(
491+
llama_cpp.llama_sampler_init_temp(
483492
self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
484493
)
485-
id = llama_cpp.llama_sample_token_mirostat(
494+
id = llama_cpp.llama_sampler_init_mirostat(
486495
self.ctx,
487496
candidates_p,
488497
llama_cpp.c_float(self.params.mirostat_tau),
@@ -495,7 +504,7 @@ def generate(self):
495504
llama_cpp.llama_sample_temperature(
496505
self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
497506
)
498-
id = llama_cpp.llama_sample_token_mirostat_v2(
507+
id = llama_cpp.llama_sampler_init_mirostat_v2(
499508
self.ctx,
500509
candidates_p,
501510
llama_cpp.c_float(self.params.mirostat_tau),
@@ -504,31 +513,31 @@ def generate(self):
504513
)
505514
else:
506515
# Temperature sampling
507-
llama_cpp.llama_sample_top_k(
516+
llama_cpp.llama_sampler_init_top_k(
508517
self.ctx,
509518
candidates_p,
510519
top_k,
511520
min_keep=llama_cpp.c_size_t(1),
512521
)
513-
llama_cpp.llama_sample_tail_free(
522+
llama_cpp.llama_sampler_init_top_n_sigma(
514523
self.ctx,
515524
candidates_p,
516-
llama_cpp.c_float(self.params.tfs_z),
525+
llama_cpp.c_float(self.params.top_n_sigma),
517526
min_keep=llama_cpp.c_size_t(1),
518527
)
519-
llama_cpp.llama_sample_typical(
528+
llama_cpp.llama_sampler_init_typical(
520529
self.ctx,
521530
candidates_p,
522531
llama_cpp.c_float(self.params.typical_p),
523532
min_keep=llama_cpp.c_size_t(1),
524533
)
525-
llama_cpp.llama_sample_top_p(
534+
llama_cpp.llama_sampler_init_top_p(
526535
self.ctx,
527536
candidates_p,
528537
llama_cpp.c_float(self.params.top_p),
529538
min_keep=llama_cpp.c_size_t(1),
530539
)
531-
llama_cpp.llama_sample_temperature(
540+
llama_cpp.llama_sampler_init_temp(
532541
self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
533542
)
534543
id = llama_cpp.llama_sample_token(self.ctx, candidates_p)

0 commit comments

Comments (0)