@@ -677,6 +677,13 @@ def _init_sampler(
677
677
mirostat_mode : int = 0 ,
678
678
mirostat_eta : float = 0.1 ,
679
679
mirostat_tau : float = 5.0 ,
680
+ xtc_probability : float = 0.0 ,
681
+ xtc_threshold : float = 0.1 ,
682
+ dry_multiplier : float = 0.0 ,
683
+ dry_allowed_length : int = 2 ,
684
+ dry_base : float = 1.75 ,
685
+ dry_range : int = 0 ,
686
+ dry_seq_breakers : list [str ] = [],
680
687
penalize_nl : bool = True ,
681
688
logits_processor : Optional [LogitsProcessorList ] = None ,
682
689
grammar : Optional [LlamaGrammar ] = None ,
@@ -744,13 +751,15 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
744
751
else :
745
752
n_probs = 0
746
753
min_keep = max (1 , n_probs )
754
+ sampler .add_dry (self ._model , dry_multiplier , dry_base , dry_allowed_length , dry_range , dry_seq_breakers )
747
755
sampler .add_top_k (top_k )
748
756
sampler .add_tail_free (tfs_z , min_keep )
749
757
sampler .add_typical (typical_p , min_keep )
750
758
sampler .add_top_p (top_p , min_keep )
751
759
sampler .add_min_p (min_p , min_keep )
752
760
sampler .add_temp (temp )
753
761
sampler .add_dist (self ._seed )
762
+ sampler .add_xtc (xtc_probability , xtc_threshold , min_keep , self ._seed )
754
763
return sampler
755
764
756
765
def sample (
@@ -767,6 +776,13 @@ def sample(
767
776
mirostat_mode : int = 0 ,
768
777
mirostat_eta : float = 0.1 ,
769
778
mirostat_tau : float = 5.0 ,
779
+ xtc_probability : float = 0.0 ,
780
+ xtc_threshold : float = 0.1 ,
781
+ dry_multiplier : float = 0.0 ,
782
+ dry_allowed_length : int = 2 ,
783
+ dry_base : float = 1.75 ,
784
+ dry_range : int = 0 ,
785
+ dry_seq_breakers : list [str ] = [],
770
786
penalize_nl : bool = True ,
771
787
logits_processor : Optional [LogitsProcessorList ] = None ,
772
788
grammar : Optional [LlamaGrammar ] = None ,
@@ -802,6 +818,13 @@ def sample(
802
818
mirostat_mode = mirostat_mode ,
803
819
mirostat_tau = mirostat_tau ,
804
820
mirostat_eta = mirostat_eta ,
821
+ xtc_probability = xtc_probability ,
822
+ xtc_threshold = xtc_threshold ,
823
+ dry_multiplier = dry_multiplier ,
824
+ dry_allowed_length = dry_allowed_length ,
825
+ dry_base = dry_base ,
826
+ dry_range = dry_range ,
827
+ dry_seq_breakers = dry_seq_breakers ,
805
828
penalize_nl = penalize_nl ,
806
829
logits_processor = logits_processor ,
807
830
grammar = grammar ,
@@ -831,6 +854,13 @@ def generate(
831
854
mirostat_mode : int = 0 ,
832
855
mirostat_tau : float = 5.0 ,
833
856
mirostat_eta : float = 0.1 ,
857
+ xtc_probability : float = 0.0 ,
858
+ xtc_threshold : float = 0.1 ,
859
+ dry_multiplier : float = 0.0 ,
860
+ dry_allowed_length : int = 2 ,
861
+ dry_base : float = 1.75 ,
862
+ dry_range : int = 0 ,
863
+ dry_seq_breakers : list [str ] = [],
834
864
penalize_nl : bool = True ,
835
865
logits_processor : Optional [LogitsProcessorList ] = None ,
836
866
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
@@ -870,6 +900,13 @@ def generate(
870
900
mirostat_mode = mirostat_mode ,
871
901
mirostat_tau = mirostat_tau ,
872
902
mirostat_eta = mirostat_eta ,
903
+ xtc_probability = xtc_probability ,
904
+ xtc_threshold = xtc_threshold ,
905
+ dry_multiplier = dry_multiplier ,
906
+ dry_allowed_length = dry_allowed_length ,
907
+ dry_base = dry_base ,
908
+ dry_range = dry_range ,
909
+ dry_seq_breakers = dry_seq_breakers ,
873
910
penalize_nl = penalize_nl ,
874
911
logits_processor = logits_processor ,
875
912
grammar = grammar ,
@@ -922,6 +959,13 @@ def generate(
922
959
mirostat_mode = mirostat_mode ,
923
960
mirostat_tau = mirostat_tau ,
924
961
mirostat_eta = mirostat_eta ,
962
+ xtc_probability = xtc_probability ,
963
+ xtc_threshold = xtc_threshold ,
964
+ dry_multiplier = dry_multiplier ,
965
+ dry_allowed_length = dry_allowed_length ,
966
+ dry_base = dry_base ,
967
+ dry_range = dry_range ,
968
+ dry_seq_breakers = dry_seq_breakers ,
925
969
logits_processor = logits_processor ,
926
970
grammar = grammar ,
927
971
penalize_nl = penalize_nl ,
@@ -1138,6 +1182,13 @@ def _create_completion(
1138
1182
mirostat_mode : int = 0 ,
1139
1183
mirostat_tau : float = 5.0 ,
1140
1184
mirostat_eta : float = 0.1 ,
1185
+ xtc_probability : float = 0.0 ,
1186
+ xtc_threshold : float = 0.1 ,
1187
+ dry_multiplier : float = 0.0 ,
1188
+ dry_allowed_length : int = 2 ,
1189
+ dry_base : float = 1.75 ,
1190
+ dry_range : int = 0 ,
1191
+ dry_seq_breakers : list [str ] = [],
1141
1192
model : Optional [str ] = None ,
1142
1193
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1143
1194
logits_processor : Optional [LogitsProcessorList ] = None ,
@@ -1326,6 +1377,13 @@ def logit_bias_processor(
1326
1377
mirostat_mode = mirostat_mode ,
1327
1378
mirostat_tau = mirostat_tau ,
1328
1379
mirostat_eta = mirostat_eta ,
1380
+ xtc_probability = xtc_probability ,
1381
+ xtc_threshold = xtc_threshold ,
1382
+ dry_multiplier = dry_multiplier ,
1383
+ dry_allowed_length = dry_allowed_length ,
1384
+ dry_base = dry_base ,
1385
+ dry_range = dry_range ,
1386
+ dry_seq_breakers = dry_seq_breakers ,
1329
1387
frequency_penalty = frequency_penalty ,
1330
1388
presence_penalty = presence_penalty ,
1331
1389
repeat_penalty = repeat_penalty ,
@@ -1758,6 +1816,13 @@ def create_completion(
1758
1816
mirostat_mode : int = 0 ,
1759
1817
mirostat_tau : float = 5.0 ,
1760
1818
mirostat_eta : float = 0.1 ,
1819
+ xtc_probability : float = 0.0 ,
1820
+ xtc_threshold : float = 0.1 ,
1821
+ dry_multiplier : float = 0.0 ,
1822
+ dry_allowed_length : int = 2 ,
1823
+ dry_base : float = 1.75 ,
1824
+ dry_range : int = 0 ,
1825
+ dry_seq_breakers : list [str ] = [],
1761
1826
model : Optional [str ] = None ,
1762
1827
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1763
1828
logits_processor : Optional [LogitsProcessorList ] = None ,
@@ -1821,6 +1886,13 @@ def create_completion(
1821
1886
mirostat_mode = mirostat_mode ,
1822
1887
mirostat_tau = mirostat_tau ,
1823
1888
mirostat_eta = mirostat_eta ,
1889
+ xtc_probability = xtc_probability ,
1890
+ xtc_threshold = xtc_threshold ,
1891
+ dry_multiplier = dry_multiplier ,
1892
+ dry_allowed_length = dry_allowed_length ,
1893
+ dry_base = dry_base ,
1894
+ dry_range = dry_range ,
1895
+ dry_seq_breakers = dry_seq_breakers ,
1824
1896
model = model ,
1825
1897
stopping_criteria = stopping_criteria ,
1826
1898
logits_processor = logits_processor ,
@@ -1855,6 +1927,13 @@ def __call__(
1855
1927
mirostat_mode : int = 0 ,
1856
1928
mirostat_tau : float = 5.0 ,
1857
1929
mirostat_eta : float = 0.1 ,
1930
+ xtc_probability : float = 0.0 ,
1931
+ xtc_threshold : float = 0.1 ,
1932
+ dry_multiplier : float = 0.0 ,
1933
+ dry_allowed_length : int = 2 ,
1934
+ dry_base : float = 1.75 ,
1935
+ dry_range : int = 0 ,
1936
+ dry_seq_breakers : list [str ] = [],
1858
1937
model : Optional [str ] = None ,
1859
1938
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1860
1939
logits_processor : Optional [LogitsProcessorList ] = None ,
@@ -1918,6 +1997,13 @@ def __call__(
1918
1997
mirostat_mode = mirostat_mode ,
1919
1998
mirostat_tau = mirostat_tau ,
1920
1999
mirostat_eta = mirostat_eta ,
2000
+ xtc_probability = xtc_probability ,
2001
+ xtc_threshold = xtc_threshold ,
2002
+ dry_multiplier = dry_multiplier ,
2003
+ dry_allowed_length = dry_allowed_length ,
2004
+ dry_base = dry_base ,
2005
+ dry_range = dry_range ,
2006
+ dry_seq_breakers = dry_seq_breakers ,
1921
2007
model = model ,
1922
2008
stopping_criteria = stopping_criteria ,
1923
2009
logits_processor = logits_processor ,
@@ -1949,6 +2035,13 @@ def create_chat_completion(
1949
2035
mirostat_mode : int = 0 ,
1950
2036
mirostat_tau : float = 5.0 ,
1951
2037
mirostat_eta : float = 0.1 ,
2038
+ xtc_probability : float = 0.0 ,
2039
+ xtc_threshold : float = 0.1 ,
2040
+ dry_multiplier : float = 0.0 ,
2041
+ dry_allowed_length : int = 2 ,
2042
+ dry_base : float = 1.75 ,
2043
+ dry_range : int = 0 ,
2044
+ dry_seq_breakers : list [str ] = [],
1952
2045
model : Optional [str ] = None ,
1953
2046
logits_processor : Optional [LogitsProcessorList ] = None ,
1954
2047
grammar : Optional [LlamaGrammar ] = None ,
@@ -2022,6 +2115,13 @@ def create_chat_completion(
2022
2115
mirostat_mode = mirostat_mode ,
2023
2116
mirostat_tau = mirostat_tau ,
2024
2117
mirostat_eta = mirostat_eta ,
2118
+ xtc_probability = xtc_probability ,
2119
+ xtc_threshold = xtc_threshold ,
2120
+ dry_multiplier = dry_multiplier ,
2121
+ dry_allowed_length = dry_allowed_length ,
2122
+ dry_base = dry_base ,
2123
+ dry_range = dry_range ,
2124
+ dry_seq_breakers = dry_seq_breakers ,
2025
2125
model = model ,
2026
2126
logits_processor = logits_processor ,
2027
2127
grammar = grammar ,
0 commit comments