Commit f11a0e2

fix: linter for dry issues

1 parent 03ee5df commit f11a0e2

File tree

1 file changed: +51 -46 lines changed

vllm/v1/sample/ops/penalties.py

Lines changed: 51 additions & 46 deletions
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-from typing import List, Set, Optional, Dict, Any
-from vllm.v1.sample.metadata import SamplingMetadata
+from typing import List, Optional, Dict, Any
+from vllm.v1.sample.metadata import SamplingMetadata

 from vllm.model_executor.layers.utils import apply_penalties
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -69,8 +69,8 @@ def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
 _DRY_DEFAULT_MULTIPLIER = 0.0
 _DRY_DEFAULT_BASE = 1.0
 _DRY_DEFAULT_ALLOWED_LEN = 3
-_DRY_DEFAULT_RANGE = 1500
-_DRY_DEFAULT_BREAKERS: Set[int] = set()
+_DRY_DEFAULT_RANGE = 1500
+_DRY_DEFAULT_BREAKERS: set[int] = set()


 def apply_dry(
@@ -113,8 +113,9 @@ def apply_dry(
     for irow in range(batch_size):
         # Ensure sampling_metadata has data for this row index
         if irow >= len(sampling_metadata.extra_data):
-            # If metadata doesn't cover this row (shouldn't happen in normal flow), skip
-            continue
+            # If metadata doesn't cover this row (shouldn't happen in
+            # normal flow), skip
+            continue
         extra_data = sampling_metadata.extra_data[irow]
         if not extra_data:
             continue
@@ -135,7 +136,9 @@
         # Assuming prompt_token_ids is available and correctly indexed
         # Need prompt_lens if prompt_token_ids is padded
         prompt_len_attr = getattr(sampling_metadata, 'prompt_lens', None)
-        current_prompt_len = prompt_len_attr[irow] if prompt_len_attr and irow < len(prompt_len_attr) else None
+        current_prompt_len = None
+        if prompt_len_attr and irow < len(prompt_len_attr):
+            current_prompt_len = prompt_len_attr[irow]

         # Ensure prompt_token_ids covers this row
         if irow >= sampling_metadata.prompt_token_ids.shape[0]:
@@ -186,8 +189,10 @@ def apply_dry(
         max_ngram = 0
         for offset in range(1, min(seq_len, _DRY_MAX_NGRAM + 1)):
             check_idx = -offset - 1
-            if check_idx < -seq_len: break
-            if break_mask[check_idx]: break
+            if check_idx < -seq_len:
+                break
+            if break_mask[check_idx]:
+                break
             max_ngram = offset

         if max_ngram < allowed_length:
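
For reference, the scan above caps how long an n-gram comparison can run: it walks backwards from the final token and stops at the first breaker token or the start of the sequence. Below is a minimal standalone sketch of the same scan, with a hypothetical helper name and an assumed cap of 12 standing in for _DRY_MAX_NGRAM (whose actual value is defined elsewhere in this file):

def max_comparable_ngram(tokens: list[int], breakers: set[int],
                         max_ngram: int = 12) -> int:
    # Walk backwards from the token before the last one; stop at a
    # breaker or at the start of the sequence (mirrors the hunk above).
    seq_len = len(tokens)
    longest = 0
    for offset in range(1, min(seq_len, max_ngram + 1)):
        check_idx = -offset - 1
        if check_idx < -seq_len:
            break
        if tokens[check_idx] in breakers:
            break
        longest = offset
    return longest

# Token 10 is a breaker: only the two tokens between it and the
# final token can take part in a match.
print(max_comparable_ngram([5, 10, 7, 7, 7], breakers={10}))  # 2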
@@ -203,43 +208,43 @@
         if len(endpoint_indices) > _DRY_MAX_OCCURRENCES:
             endpoint_indices = endpoint_indices[-_DRY_MAX_OCCURRENCES:]

-        # 6. Calculate match lengths for potential next tokens
-        ngram_lens = torch.zeros(vocab_size, dtype=torch.int32, device=device)
-        found_early_exit_match = False
-
-        for idx_tensor in reversed(endpoint_indices):
-            idx = idx_tensor.item()
+        # 6. Iterate through found endpoints to compare n-grams
+        ngram_occurrence_count = 0
+        for end_idx in endpoint_indices:
+            # Compare n-grams backwards from endpoints
             match_len = 0
-            for unwind in range(1, max_ngram + 1):
-                current_idx = idx - unwind
-                history_idx = seq_len - 1 - unwind
-                if current_idx < 0: break
-                # Check breaks using the precomputed mask
-                if break_mask[current_idx] or break_mask[history_idx]: break
-                if token_seq_tensor[current_idx] != token_seq_tensor[history_idx]: break
-                match_len = unwind
-
-            if match_len >= allowed_length:  # Match length must meet minimum
-                next_tok_idx = idx + 1
-                if next_tok_idx < seq_len:
-                    next_tok = token_seq_tensor[next_tok_idx].item()
-                    # Use match_len as the length of the *matched* sequence
-                    new_len = match_len
-                    current_max = ngram_lens[next_tok].item()
-                    ngram_lens[next_tok] = max(current_max, new_len)
-                    if new_len >= _DRY_EARLY_EXIT_MATCH_LEN:
-                        found_early_exit_match = True
-
-            if found_early_exit_match: break
-
-        # 7. Apply penalty to logits for this row
-        penalty_mask = ngram_lens > 0
-        if penalty_mask.any():
-            match_lengths_for_penalty = ngram_lens[penalty_mask]
-            # Clamp exponent >= 0
-            exponents = (match_lengths_for_penalty.float() - allowed_length).clamp_(min=0.0)
-            scales = base ** exponents
-            logits[irow, penalty_mask] -= multiplier * scales
-        # --- End of DRY logic for row ---
+            for ngram_offset in range(max_ngram):
+                p_idx = end_idx - ngram_offset
+                q_idx = seq_len - 1 - ngram_offset
+                if p_idx < 0 or q_idx < 0:
+                    break  # Should not happen with logic checks
+
+                # Early exit inner loop if mismatch
+                if token_seq_tensor[p_idx] != token_seq_tensor[q_idx]:
+                    break
+
+                # Do not count matches across breaker tokens
+                if break_mask[p_idx - 1] or break_mask[q_idx - 1]:
+                    break
+
+                match_len = ngram_offset + 1
+
+            # Penalize the token that would continue the n-gram
+            next_token_idx = end_idx + 1
+            if 0 <= next_token_idx < seq_len:
+                next_token = token_seq_tensor[next_token_idx]
+                if not break_mask[next_token_idx - 1]:  # Check break mask before penalizing
+                    if 0 <= next_token < vocab_size:  # Ensure token is within vocab bounds
+                        penalty = (multiplier ** (ngram_occurrence_count + 1)) * (base ** match_len)
+                        logits[irow, next_token] /= penalty
+                        ngram_occurrence_count += 1
+
+            # Stop checking this endpoint if max occurrences reached
+            if ngram_occurrence_count >= _DRY_MAX_OCCURRENCES:
+                break
+
+            # Early exit outer loop if a long match is found
+            if match_len >= _DRY_EARLY_EXIT_MATCH_LEN:
+                break

     return logits
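
The rewritten step 6 replaces the old additive update (logits[irow, penalty_mask] -= multiplier * scales) with a divisive per-endpoint update: the logit of the token that would continue a matched n-gram is divided by multiplier ** (occurrences + 1) * base ** match_len. Below is a minimal sketch of that update on a toy logits row; the helper name and the multiplier/base values are illustrative, not part of this module:

import torch

def dry_penalize_row(logits_row: torch.Tensor, next_token: int,
                     match_len: int, occurrence: int,
                     multiplier: float = 0.8, base: float = 1.75) -> None:
    # Same divisive update as the new hunk above:
    # penalty = multiplier^(occurrence + 1) * base^match_len
    penalty = (multiplier ** (occurrence + 1)) * (base ** match_len)
    logits_row[next_token] /= penalty

row = torch.zeros(8)
row[3] = 4.0  # token 3 would extend a repeated 2-gram
dry_penalize_row(row, next_token=3, match_len=2, occurrence=0)
print(row[3].item())  # 4.0 / (0.8 * 1.75**2) ≈ 1.63

One consequence of the divisive form: it only dampens positive logits, whereas the old subtraction lowered a matched token's logit regardless of its sign.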
