
Commit eda526d

llmixer committed
Added DRY and XTC samplers
1 parent 7403e00 commit eda526d

File tree

- llama_cpp/_internals.py
- llama_cpp/llama.py
- llama_cpp/llama_cpp.py

3 files changed: +148 -0 lines changed


Diff for: llama_cpp/_internals.py

+16 lines

```diff
@@ -800,6 +800,22 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
         sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
         self._add_sampler(sampler)
 
+    def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
+        sampler = llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
+        self._add_sampler(sampler)
+
+    def add_dry(self, model: LlamaModel, multiplier: float, base: float,
+                allowed_length: int, penalty_last_n: int, seq_breakers: list[str] = []):
+
+        # Convert Python strings to bytes
+        seq_breakers_bytes = [s.encode('utf-8') for s in seq_breakers]
+        # Create array of char*
+        arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
+        sampler = llama_cpp.llama_sampler_init_dry(model.model, multiplier, base,
+                                                   allowed_length, penalty_last_n,
+                                                   arr, len(seq_breakers))
+        self._add_sampler(sampler)
+
     def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
         sampler = llama_cpp.llama_sampler_init_grammar(
             model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
```

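The two new chain methods wrap llama.cpp's `llama_sampler_init_xtc` (exclude-top-choices) and `llama_sampler_init_dry` (sequence-repetition penalty). A minimal usage sketch, not part of the commit: it assumes the enclosing sampler-chain wrapper in `_internals.py` (referred to here as `LlamaSampler`) and an already-loaded internal `LlamaModel` handle, and all numeric values are illustrative.

```python
# Sketch only: `model` is an internal LlamaModel handle and `LlamaSampler` is
# assumed to be the sampler-chain wrapper that these methods were added to.
sampler = LlamaSampler()

# DRY: penalize exact repetition of recent token sequences.
sampler.add_dry(
    model,
    multiplier=0.8,        # overall penalty strength; 0.0 effectively disables DRY
    base=1.75,             # exponential base of the penalty ramp
    allowed_length=2,      # repeats up to this length go unpenalized
    penalty_last_n=1024,   # how many recent tokens are scanned for repeats
    seq_breakers=["\n", ":", '"', "*"],  # strings that interrupt sequence matching
)

# XTC: with the given probability, drop top candidates above the threshold,
# keeping at least `min_keep` tokens.
sampler.add_xtc(probability=0.5, threshold=0.1, min_keep=1, seed=1337)
```
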
Diff for: llama_cpp/llama.py

+100 lines

```diff
@@ -677,6 +677,13 @@ def _init_sampler(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -744,13 +751,15 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
         else:
             n_probs = 0
         min_keep = max(1, n_probs)
+        sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
         sampler.add_top_k(top_k)
         sampler.add_tail_free(tfs_z, min_keep)
         sampler.add_typical(typical_p, min_keep)
         sampler.add_top_p(top_p, min_keep)
         sampler.add_min_p(min_p, min_keep)
         sampler.add_temp(temp)
         sampler.add_dist(self._seed)
+        sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
         return sampler
 
     def sample(
@@ -767,6 +776,13 @@ def sample(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -802,6 +818,13 @@ def sample(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -831,6 +854,13 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -870,6 +900,13 @@ def generate(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -922,6 +959,13 @@ def generate(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             logits_processor=logits_processor,
             grammar=grammar,
             penalize_nl=penalize_nl,
@@ -1138,6 +1182,13 @@ def _create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1326,6 +1377,13 @@ def logit_bias_processor(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -1758,6 +1816,13 @@ def create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1821,6 +1886,13 @@ def create_completion(
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
+           xtc_probability=xtc_probability,
+           xtc_threshold=xtc_threshold,
+           dry_multiplier=dry_multiplier,
+           dry_allowed_length=dry_allowed_length,
+           dry_base=dry_base,
+           dry_range=dry_range,
+           dry_seq_breakers=dry_seq_breakers,
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
@@ -1855,6 +1927,13 @@ def __call__(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1918,6 +1997,13 @@ def __call__(
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
+           xtc_probability=xtc_probability,
+           xtc_threshold=xtc_threshold,
+           dry_multiplier=dry_multiplier,
+           dry_allowed_length=dry_allowed_length,
+           dry_base=dry_base,
+           dry_range=dry_range,
+           dry_seq_breakers=dry_seq_breakers,
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
@@ -1949,6 +2035,13 @@ def create_chat_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -2022,6 +2115,13 @@ def create_chat_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
```

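With `llama.py` updated, the new knobs are threaded through `sample`, `generate`, `create_completion`, `__call__`, and `create_chat_completion`. A hedged usage sketch through the high-level API follows; only the parameter names and defaults come from this diff, while the model path and values are illustrative. Note that `dry_range` is passed through to the binding's `dry_penalty_last_n` argument.

```python
from llama_cpp import Llama

# Illustrative only: path and values are placeholders, not part of the commit.
llm = Llama(model_path="./model.gguf")

out = llm.create_completion(
    "Write a short story about a lighthouse keeper.",
    max_tokens=256,
    temperature=0.8,
    # XTC: leaving xtc_probability at its 0.0 default keeps XTC disabled.
    xtc_probability=0.5,
    xtc_threshold=0.1,
    # DRY: leaving dry_multiplier at its 0.0 default keeps DRY disabled.
    dry_multiplier=0.8,
    dry_base=1.75,
    dry_allowed_length=2,
    dry_range=1024,          # forwarded as dry_penalty_last_n in the binding
    dry_seq_breakers=["\n", ":", '"', "*"],
)
print(out["choices"][0]["text"])
```
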
Diff for: llama_cpp/llama_cpp.py

+32 lines

```diff
@@ -3244,6 +3244,38 @@ def llama_sampler_init_xtc(
 ) -> llama_sampler_p:
     ...
 
+# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+#         const struct llama_model * model,
+#                            float   dry_multiplier,
+#                            float   dry_base,
+#                          int32_t   dry_allowed_length,
+#                          int32_t   dry_penalty_last_n,
+#                      const char ** seq_breakers,
+#                           size_t   num_breakers);
+@ctypes_function(
+    "llama_sampler_init_dry",
+    [
+        llama_model_p_ctypes,
+        ctypes.c_float,
+        ctypes.c_float,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_char_p),
+        ctypes.c_size_t
+    ],
+    llama_sampler_p_ctypes,
+)
+def llama_sampler_init_dry(
+    model: llama_model_p,
+    dry_multiplier: float,
+    dry_base: float,
+    dry_allowed_length: int,
+    dry_penalty_last_n: int,
+    seq_breakers: list[str],
+    num_breakers: int,
+) -> llama_sampler_p:
+    ...
+
 
 # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
 # /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
```

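For completeness, the raw binding can also be called directly. The sketch below mirrors the `char **` marshalling used in `_internals.py` above; the model-loading calls are the library's existing low-level API, the path is a placeholder, and the numeric values are illustrative.

```python
import ctypes
import llama_cpp

# Load a model with the existing low-level API (placeholder path).
params = llama_cpp.llama_model_default_params()
model = llama_cpp.llama_load_model_from_file(b"./model.gguf", params)

# Marshal Python strings into the `const char ** seq_breakers` argument,
# exactly as the new add_dry() wrapper does.
seq_breakers = ["\n", ":", '"', "*"]
seq_breakers_bytes = [s.encode("utf-8") for s in seq_breakers]
arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)

sampler = llama_cpp.llama_sampler_init_dry(
    model,              # const struct llama_model *
    0.8,                # dry_multiplier
    1.75,               # dry_base
    2,                  # dry_allowed_length
    1024,               # dry_penalty_last_n
    arr,                # const char ** seq_breakers
    len(seq_breakers),  # num_breakers
)
```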