Commit de23687

Fix repetition penalty aligned with huggingface (#1577)

With this change, the repetition penalty counts tokens from the prompt as well as from the generated text, matching HuggingFace's behavior; the presence and frequency penalties continue to apply only to generated tokens, following the OpenAI API definition.
1 parent 4cea74c commit de23687

File tree

2 files changed: +50 -32 lines

    vllm/model_executor/layers/sampler.py
    vllm/sampling_params.py


vllm/model_executor/layers/sampler.py

Lines changed: 47 additions & 29 deletions
@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
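To make the ordering concrete, here is a small self-contained sketch of steps 4-6 (the penalties of step 3 are worked through in a later example). The function and its values are illustrative only, not vLLM's implementation:

import torch

def toy_sample(logits: torch.Tensor, temperature: float = 1.0,
               top_k: int = 50) -> torch.Tensor:
    # Step 4: temperature scaling.
    logits = logits / temperature
    # Step 5: top-k truncation (top-p omitted for brevity).
    topk_vals, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
    logits = logits.masked_fill(logits < topk_vals[..., -1:], float("-inf"))
    # Step 6: sample the next token from the truncated distribution.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

next_token = toy_sample(torch.randn(1, 32000), temperature=0.8, top_k=40)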
@@ -50,14 +50,12 @@ def forward(
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, input_metadata)
         # Apply presence and frequency penalties.
-        output_tokens = _get_output_tokens(input_metadata)
-        assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties, repetition_penalties = (
             _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+        logits = _apply_penalties(logits, input_metadata, presence_penalties,
                                   frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
@@ -146,7 +144,10 @@ def _get_penalties(
     return presence_penalties, frequency_penalties, repetition_penalties


-def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
+def _get_prompt_and_output_tokens(
+    input_metadata: InputMetadata
+) -> Tuple[List[List[int]], List[List[int]]]:
+    prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
             prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
+            prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
-    return output_tokens
+    return prompt_tokens, output_tokens
+
+
+def _get_bin_counts_and_mask(
+    logits: torch.Tensor,
+    tokens: List[List[int]],
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    max_len = max(len(tokens) for tokens in tokens)
+    padded_tokens = [
+        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
+    ]
+    tokens_tensor = torch.tensor(padded_tokens,
+                                 dtype=torch.long,
+                                 device=logits.device)
+
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask


 def _apply_logits_processors(logits: torch.Tensor,
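The new _get_bin_counts_and_mask helper counts, per sequence, how often each token id appears, using a single batched scatter_add_; ragged token lists are padded with the out-of-range id vocab_size, which lands in an extra bin that is sliced off afterwards. A tiny standalone demonstration (toy sizes, not from the commit):

import torch

vocab_size = 5
tokens = [[0, 2, 2], [1]]  # two sequences, ragged lengths

# Pad with vocab_size so every row has the same length.
max_len = max(len(t) for t in tokens)
padded = [t + [vocab_size] * (max_len - len(t)) for t in tokens]
tokens_tensor = torch.tensor(padded, dtype=torch.long)

# vocab_size + 1 bins: the last bin absorbs the padding counts.
bin_counts = torch.zeros((len(tokens), vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin

print(bin_counts)  # tensor([[1, 0, 2, 0, 0], [0, 1, 0, 0, 0]])
mask = bin_counts > 0  # True where a token occurred at least once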
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,

 def _apply_penalties(
     logits: torch.Tensor,
-    output_tokens: List[List[int]],
+    input_metadata: InputMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
-        if not output_tokens[i]:
-            continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
         r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
         # Return early if all sequences have zero penalties.
         return logits

-    max_output_len = max(len(tokens) for tokens in output_tokens)
-    padded_output_tokens = [
-        tokens + [vocab_size] * (max_output_len - len(tokens))
-        for tokens in output_tokens
-    ]
-    output_tokens_tensor = torch.tensor(padded_output_tokens,
-                                        dtype=torch.long,
-                                        device=logits.device)
+    prompt_tokens, output_tokens = (
+        _get_prompt_and_output_tokens(input_metadata))
+    assert len(prompt_tokens) == logits.shape[0]
+    assert len(output_tokens) == logits.shape[0]

-    # Compute the bin counts for the output tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, output_tokens_tensor,
-                            torch.ones_like(output_tokens_tensor))
-    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
-    mask = bin_counts > 0
+    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
+        logits, prompt_tokens, vocab_size, num_seqs)
+    output_bin_counts, output_mask = _get_bin_counts_and_mask(
+        logits, output_tokens, vocab_size, num_seqs)

     repetition_penalties = torch.tensor(repetition_penalties,
                                         dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
                                         device=logits.device)

     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~mask] = 1.0
+    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits

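The net effect on the logits can be checked by hand. The snippet below mirrors the post-change logic of _apply_penalties on toy numbers, with scalar penalties standing in for the per-sequence tensors:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])           # 1 sequence, vocab of 3
prompt_mask = torch.tensor([[True, False, False]])   # token 0 is in the prompt
output_mask = torch.tensor([[False, True, False]])   # token 1 was generated
output_bin_counts = torch.tensor([[0, 1, 0]])

rep, freq, pres = 1.2, 0.5, 0.3
rep_pen = torch.full_like(logits, rep)
rep_pen[~(prompt_mask | output_mask)] = 1.0  # unseen tokens are not penalized

# HuggingFace-style repetition penalty: shrink positive logits,
# push negative logits further down.
logits = torch.where(logits > 0, logits / rep_pen, logits * rep_pen)

# OpenAI-style frequency/presence penalties use output tokens only.
logits = logits - freq * output_bin_counts
logits = logits - pres * output_mask

print(logits)  # tensor([[ 1.6667, -2.0000,  0.5000]])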
vllm/sampling_params.py

Lines changed: 3 additions & 3 deletions
@@ -42,9 +42,9 @@ class SamplingParams:
         model to use new tokens, while values < 0 encourage the model to
         repeat tokens.
     repetition_penalty: Float that penalizes new tokens based on whether
-        they appear in the generated text so far. Values > 1 encourage the
-        model to use new tokens, while values < 1 encourage the model to
-        repeat tokens.
+        they appear in the prompt and the generated text so far. Values > 1
+        encourage the model to use new tokens, while values < 1 encourage
+        the model to repeat tokens.
     temperature: Float that controls the randomness of the sampling. Lower
         values make the model more deterministic, while higher values make
         the model more random. Zero means greedy sampling.
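For reference, a sketch of how the parameter is exercised from vLLM's offline API (the model name is just an example):

from vllm import LLM, SamplingParams

# repetition_penalty > 1.0 discourages tokens already seen in the prompt
# or in the generation so far; 1.0 disables the penalty.
params = SamplingParams(temperature=0.8, repetition_penalty=1.2, max_tokens=64)
llm = LLM(model="facebook/opt-125m")  # example model
outputs = llm.generate(["The quick brown fox"], params)
print(outputs[0].outputs[0].text)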
