
Commit 299bd65

Added support for min_p (#921)
* Added support for min_p. My small contribution to this great project. Ref: ggml-org/llama.cpp#3841 Closes: abetlen/llama-cpp-python#911
* Fix for negative temp (sample_softmax)
1 parent 0d783f8 commit 299bd65

File tree

3 files changed: +54 -2 lines changed

llama_cpp/llama.py (+28 -2)

@@ -1046,6 +1046,8 @@ def sample(
         self,
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         frequency_penalty: float = 0.0,
@@ -1108,7 +1110,10 @@ def sample(
             grammar=grammar,
         )
 
-        if temp == 0.0:
+        if temp < 0.0:
+            self._ctx.sample_softmax(candidates=self._candidates)
+            id = self._candidates.candidates.data[0].id
+        elif temp == 0.0:
             id = self._ctx.sample_token_greedy(candidates=self._candidates)
         elif mirostat_mode == 1:
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
@@ -1130,8 +1135,9 @@ def sample(
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
             self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-            self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1)
+            self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
             self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
+            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
             id = self._ctx.sample_token(candidates=self._candidates)
         if grammar is not None:
@@ -1143,6 +1149,8 @@ def generate(
         tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         reset: bool = True,
@@ -1200,6 +1208,8 @@ def generate(
             token = self.sample(
                 top_k=top_k,
                 top_p=top_p,
+                min_p=min_p,
+                typical_p=typical_p,
                 temp=temp,
                 repeat_penalty=repeat_penalty,
                 frequency_penalty=frequency_penalty,
@@ -1298,6 +1308,8 @@ def _create_completion(
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1398,6 +1410,8 @@ def _create_completion(
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             temp=temperature,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1772,6 +1786,8 @@ def create_completion(
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1818,6 +1834,8 @@ def create_completion(
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
@@ -1849,6 +1867,8 @@ def __call__(
         max_tokens: int = 128,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1895,6 +1915,8 @@ def __call__(
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
@@ -1924,6 +1946,8 @@ def create_chat_completion(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -1970,6 +1994,8 @@ def create_chat_completion(
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             seed=seed,
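Usage note (not part of the commit): a minimal sketch of how the new min_p and typical_p parameters flow through the high-level API after this change; the model path and prompt below are placeholders. As the sample() hunks above show, min-p filtering is applied after top-p in the default sampling chain, and a negative temperature now resolves to the most probable token via sample_softmax.

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # placeholder path to a local GGUF model

    output = llm(
        "Q: Name the planets in the solar system. A:",
        max_tokens=64,
        temperature=0.8,
        top_p=0.95,
        min_p=0.05,     # keep only tokens with at least 5% of the top token's probability
        typical_p=1.0,  # 1.0 leaves locally typical sampling effectively disabled
    )
    print(output["choices"][0]["text"])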

llama_cpp/llama_chat_format.py (+16)

@@ -26,6 +26,8 @@ def __call__(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -287,6 +289,8 @@ def basic_create_chat_completion(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -330,6 +334,8 @@ def basic_create_chat_completion(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         stream=stream,
         stop=stop,
         seed=seed,
@@ -579,6 +585,8 @@ def functionary_chat_handler(
     temperature: float = 0.2,
     top_p: float = 0.95,
     top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
     stream: bool = False,
     stop: Optional[Union[str, List[str]]] = [],
     response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
@@ -761,6 +769,8 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         stream=stream,
         stop=["user:", "</s>"],
         max_tokens=max_tokens,
@@ -831,6 +841,8 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         presence_penalty=presence_penalty,
         frequency_penalty=frequency_penalty,
         repeat_penalty=repeat_penalty,
@@ -929,6 +941,8 @@ def __call__(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         response_format: Optional[
@@ -1045,6 +1059,8 @@ def __call__(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         stream=stream,
         stop=stop,
         max_tokens=max_tokens,
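The chat format handlers simply forward the new parameters down to Llama.create_completion, so they can also be set on create_chat_completion. A hedged usage sketch (placeholder model path and message):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # placeholder path

    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Hello!"}],
        temperature=0.2,
        top_p=0.95,
        top_k=40,
        min_p=0.05,     # new in this commit
        typical_p=1.0,  # new in this commit
    )
    print(response["choices"][0]["message"]["content"])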

llama_cpp/server/app.py (+10)

@@ -521,6 +521,14 @@ async def get_event_publisher(
     + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
 )
 
+min_p_field = Field(
+    default=0.05,
+    ge=0.0,
+    le=1.0,
+    description="Sets a minimum base probability threshold for token selection.\n\n"
+    + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
+)
+
 stop_field = Field(
     default=None,
     description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
@@ -593,6 +601,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     echo: bool = Field(
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@@ -788,6 +797,7 @@ class CreateChatCompletionRequest(BaseModel):
     )
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     stop: Optional[List[str]] = stop_field
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
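To make the min_p_field description concrete, here is a small standalone sketch of the thresholding it describes (an illustration of the idea, not the llama.cpp implementation): with min_p=0.05 and a top-token probability of 0.9, anything below 0.045 is dropped.

    import numpy as np

    def min_p_filter(probs: np.ndarray, min_p: float = 0.05) -> np.ndarray:
        """Zero out tokens below min_p times the top probability, then renormalize."""
        threshold = min_p * probs.max()
        kept = np.where(probs >= threshold, probs, 0.0)
        return kept / kept.sum()

    # Example from the field description: top token at 0.9, so the cutoff is 0.05 * 0.9 = 0.045.
    probs = np.array([0.90, 0.05, 0.04, 0.01])
    print(min_p_filter(probs))  # the 0.04 and 0.01 entries are filtered out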
