
Commit df53c6e

gante, menhguin, and amyeroberts authored
Generate: add min_p sampling (#30639)
* min_p
* more relaxed test to avoid numerical issues
* Update src/transformers/generation/logits_process.py
  Co-authored-by: menhguin <[email protected]>
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: menhguin <[email protected]>
* docstring clarifications
* PR comments
* Update tests/generation/test_logits_process.py
  Co-authored-by: amyeroberts <[email protected]>
* make fixup

---------

Co-authored-by: menhguin <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
1 parent 297b732 commit df53c6e

8 files changed (+147 additions, 0 deletions)


docs/source/en/internal/generation_utils.md

Lines changed: 3 additions & 0 deletions
@@ -167,6 +167,9 @@ generation.
 [[autodoc]] MinNewTokensLengthLogitsProcessor
     - __call__
 
+[[autodoc]] MinPLogitsWarper
+    - __call__
+
 [[autodoc]] NoBadWordsLogitsProcessor
     - __call__
 

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1215,6 +1215,7 @@
             "MaxTimeCriteria",
             "MinLengthLogitsProcessor",
             "MinNewTokensLengthLogitsProcessor",
+            "MinPLogitsWarper",
             "NoBadWordsLogitsProcessor",
             "NoRepeatNGramLogitsProcessor",
             "PhrasalConstraint",
@@ -5770,6 +5771,7 @@
             MaxTimeCriteria,
             MinLengthLogitsProcessor,
             MinNewTokensLengthLogitsProcessor,
+            MinPLogitsWarper,
             NoBadWordsLogitsProcessor,
             NoRepeatNGramLogitsProcessor,
             PhrasalConstraint,

src/transformers/generation/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,7 @@
         "LogitsWarper",
         "MinLengthLogitsProcessor",
         "MinNewTokensLengthLogitsProcessor",
+        "MinPLogitsWarper",
         "NoBadWordsLogitsProcessor",
         "NoRepeatNGramLogitsProcessor",
         "PrefixConstrainedLogitsProcessor",
@@ -204,6 +205,7 @@
         LogitsWarper,
         MinLengthLogitsProcessor,
         MinNewTokensLengthLogitsProcessor,
+        MinPLogitsWarper,
         NoBadWordsLogitsProcessor,
         NoRepeatNGramLogitsProcessor,
         PrefixConstrainedLogitsProcessor,

src/transformers/generation/configuration_utils.py

Lines changed: 5 additions & 0 deletions
@@ -133,6 +133,10 @@ class GenerationConfig(PushToHubMixin):
         top_p (`float`, *optional*, defaults to 1.0):
             If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
             `top_p` or higher are kept for generation.
+        min_p (`float`, *optional*):
+            Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
+            value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
+            the 0.99-0.8 range (use the opposite of normal `top_p` values).
         typical_p (`float`, *optional*, defaults to 1.0):
             Local typicality measures how similar the conditional probability of predicting a target token next is to
             the expected conditional probability of predicting a random token next, given the partial text already
@@ -306,6 +310,7 @@ def __init__(self, **kwargs):
         self.temperature = kwargs.pop("temperature", 1.0)
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
+        self.min_p = kwargs.pop("min_p", None)
         self.typical_p = kwargs.pop("typical_p", 1.0)
         self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
         self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
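
For context, a minimal usage sketch of the new field through `GenerationConfig` (the checkpoint and the concrete values below are illustrative, not part of this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Illustrative values: min_p is typically chosen in the 0.01-0.2 range and pairs well with temperature.
generation_config = GenerationConfig(do_sample=True, min_p=0.1, temperature=0.8)

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```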

src/transformers/generation/logits_process.py

Lines changed: 77 additions & 0 deletions
@@ -520,6 +520,83 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores_processed
 
 
+class MinPLogitsWarper(LogitsWarper):
+    """
+    [`LogitsWarper`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
+    probability of the most likely token. As a result, the filter becomes more aggressive in the presence of
+    high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.
+
+    Often used together with [`TemperatureLogitsWarper`]. Used as an alternative to [`TopPLogitsWarper`] and
+    [`TopKLogitsWarper`].
+
+    Created by @menhguin and @kalomaze (github handles). Code adapted from [this external PR](https://github.com/oobabooga/text-generation-webui/pull/4449/files)
+
+    Args:
+        min_p (`float`):
+            Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
+            value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
+            the 0.99-0.8 range (use the opposite of normal `top_p` values).
+        filter_value (`float`, *optional*, defaults to -inf):
+            All filtered values will be set to this float value.
+        min_tokens_to_keep (`int`, *optional*, defaults to 1):
+            Minimum number of tokens that cannot be filtered.
+
+    Examples:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+
+    >>> set_seed(1)
+    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
+    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
+
+    >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
+
+    >>> # With sampling, the output is unexpected -- sometimes too unexpected.
+    >>> outputs = model.generate(**inputs, do_sample=True)
+    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+    A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
+    <BLANKLINE>
+    <BLANKLINE>
+
+    >>> # With `min_p` sampling, the output gets restricted to high-probability tokens.
+    >>> # Pro tip: In practice, LLMs use `min_p` in the 0.01-0.2 range.
+    >>> outputs = model.generate(**inputs, do_sample=True, min_p=0.1)
+    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+    A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
+    ```
+    """
+
+    def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        if not (0 <= min_p <= 1.0):
+            raise ValueError(f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}")
+        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
+
+        self.min_p = min_p
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # Convert logits to probabilities
+        probs = torch.softmax(scores, dim=-1)
+        # Get the probability of the top token for each sequence in the batch
+        top_probs, _ = probs.max(dim=-1, keepdim=True)
+        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
+        scaled_min_p = self.min_p * top_probs
+        # Create a mask for tokens that have a probability less than the scaled min_p
+        tokens_to_remove = probs < scaled_min_p
+
+        sorted_indices = torch.argsort(scores, descending=True, dim=-1)
+        sorted_indices_to_remove = torch.gather(tokens_to_remove, dim=-1, index=sorted_indices)
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False
+
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores_processed
+
+
 class TypicalLogitsWarper(LogitsWarper):
     r"""
     [`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose
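
A quick hand-check of the thresholding logic in `__call__`, using a toy distribution (the numbers below are illustrative and are not taken from the commit):

```python
import torch

from transformers import MinPLogitsWarper

# Toy next-token distribution over 4 tokens for a single sequence.
probs = torch.tensor([[0.6, 0.25, 0.1, 0.05]])
logits = torch.log(probs)

# With min_p=0.2 the threshold is 0.2 * 0.6 = 0.12, so only the first two tokens survive;
# the rest are set to `filter_value` (-inf) and vanish after the next softmax.
warper = MinPLogitsWarper(min_p=0.2)
filtered = warper(input_ids=None, scores=logits)  # input_ids are not used by this warper
print(torch.softmax(filtered, dim=-1))  # probability mass renormalized over the surviving tokens
```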

src/transformers/generation/utils.py

Lines changed: 4 additions & 0 deletions
@@ -61,6 +61,7 @@
     LogitsProcessorList,
     MinLengthLogitsProcessor,
     MinNewTokensLengthLogitsProcessor,
+    MinPLogitsWarper,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
     PrefixConstrainedLogitsProcessor,
@@ -741,6 +742,9 @@ def _get_logits_warper(
             warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.top_p is not None and generation_config.top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.min_p is not None:
+            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
             warpers.append(
                 TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
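
Roughly what the resulting warper stack looks like when both `temperature` and `min_p` are set, sketched by hand rather than taken from `_get_logits_warper` (values are illustrative):

```python
import torch

from transformers import LogitsProcessorList, MinPLogitsWarper, TemperatureLogitsWarper

# Mirrors the ordering above: temperature scaling runs first, then the min-p filter.
warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(temperature=0.7),
        MinPLogitsWarper(min_p=0.1, min_tokens_to_keep=1),
    ]
)

fake_logits = torch.randn(1, 32)       # stand-in for next-token logits over a 32-token vocabulary
filtered = warpers(None, fake_logits)  # input_ids are not used by these two warpers
```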

src/transformers/utils/dummy_pt_objects.py

Lines changed: 7 additions & 0 deletions
@@ -303,6 +303,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
+class MinPLogitsWarper(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class NoBadWordsLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 

tests/generation/test_logits_process.py

Lines changed: 47 additions & 0 deletions
@@ -42,6 +42,7 @@
         LogitsProcessorList,
         MinLengthLogitsProcessor,
         MinNewTokensLengthLogitsProcessor,
+        MinPLogitsWarper,
         NoBadWordsLogitsProcessor,
         NoRepeatNGramLogitsProcessor,
         PrefixConstrainedLogitsProcessor,
@@ -304,6 +305,52 @@ def test_top_p_dist_warper(self):
         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
         self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
 
+    def test_min_p_dist_warper(self):
+        input_ids = None
+        vocab_size = 10
+        batch_size = 2
+
+        # create distribution and take log (inverse to Softmax as taken in MinPLogitsWarper)
+        dist = torch.log(
+            torch.tensor(
+                [
+                    [0.9, 0.0274, 0.047, 0.0274],  # two tokens should be kept (0.047 > 0.9*0.05=0.045)
+                    [0.15, 0.3, 0.3, 0.25],  # all should be kept -- no high-probability token
+                    [0.97, 0.01, 0.01, 0.01],  # only the first token should be kept
+                ],
+                device=torch_device,
+                dtype=torch.float,
+            )
+        )
+
+        min_p_warp = MinPLogitsWarper(0.05)
+        filtered_dist = torch.exp(min_p_warp(input_ids, dist))
+
+        # exp (-inf) => 0
+        EXPECTED_FILTERED_DIST = torch.tensor(
+            [[0.9, 0.0, 0.047, 0.0], [0.15, 0.3, 0.3, 0.25], [0.97, 0.0, 0.0, 0.0]],
+            device=torch_device,
+            dtype=torch.float,
+        )
+        self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
+
+        # processor should not change logits in-place
+        self.assertFalse(torch.all(min_p_warp(input_ids, dist) == dist))
+
+        # check edge cases with negative and extreme logits
+        ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float) - (vocab_size // 2)
+        ramp_logits = ramp_logits.unsqueeze(0).repeat(batch_size, 1)
+
+        # make ramp_logits more extreme
+        ramp_logits[1] = ramp_logits[1] * 100.0
+
+        # make sure at least 2 tokens are kept
+        min_p_warp = MinPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
+        filtered_dist = min_p_warp(input_ids, ramp_logits)
+
+        # first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
+        self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
+
     def test_typical_dist_warper(self):
         input_ids = None
         vocab_size = 10
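
As a sanity check on the first test row, the threshold arithmetic works out as follows (illustrative, not part of the test file):

```python
# Row [0.9, 0.0274, 0.047, 0.0274] with min_p=0.05:
top_prob = 0.9
scaled_min_p = 0.05 * top_prob                     # 0.045
kept = [p for p in [0.9, 0.0274, 0.047, 0.0274] if p >= scaled_min_p]
print(kept)                                        # [0.9, 0.047] -- the two tokens the test expects to survive
```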
