# SPDX-License-Identifier: Apache-2.0

import torch
-from typing import List, Set, Optional, Dict, Any
-from vllm.v1.sample.metadata import SamplingMetadata
+from typing import List, Optional, Dict, Any
+from vllm.v1.sample.metadata import SamplingMetadata

from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -69,8 +69,8 @@ def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
_DRY_DEFAULT_MULTIPLIER = 0.0
_DRY_DEFAULT_BASE = 1.0
_DRY_DEFAULT_ALLOWED_LEN = 3
-_DRY_DEFAULT_RANGE = 1500
-_DRY_DEFAULT_BREAKERS: Set[int] = set()
+_DRY_DEFAULT_RANGE = 1500
+_DRY_DEFAULT_BREAKERS: set[int] = set()


def apply_dry(
@@ -113,8 +113,9 @@ def apply_dry(
    for irow in range(batch_size):
        # Ensure sampling_metadata has data for this row index
        if irow >= len(sampling_metadata.extra_data):
-            # If metadata doesn't cover this row (shouldn't happen in normal flow), skip
-            continue
+            # If metadata doesn't cover this row (shouldn't happen in
+            # normal flow), skip
+            continue
        extra_data = sampling_metadata.extra_data[irow]
        if not extra_data:
            continue
@@ -135,7 +136,9 @@ def apply_dry(
        # Assuming prompt_token_ids is available and correctly indexed
        # Need prompt_lens if prompt_token_ids is padded
        prompt_len_attr = getattr(sampling_metadata, 'prompt_lens', None)
-        current_prompt_len = prompt_len_attr[irow] if prompt_len_attr and irow < len(prompt_len_attr) else None
+        current_prompt_len = None
+        if prompt_len_attr and irow < len(prompt_len_attr):
+            current_prompt_len = prompt_len_attr[irow]

        # Ensure prompt_token_ids covers this row
        if irow >= sampling_metadata.prompt_token_ids.shape[0]:
@@ -186,8 +189,10 @@ def apply_dry(
        max_ngram = 0
        for offset in range(1, min(seq_len, _DRY_MAX_NGRAM + 1)):
            check_idx = -offset - 1
-            if check_idx < -seq_len: break
-            if break_mask[check_idx]: break
+            if check_idx < -seq_len:
+                break
+            if break_mask[check_idx]:
+                break
            max_ngram = offset

        if max_ngram < allowed_length:
@@ -203,43 +208,43 @@ def apply_dry(
        if len(endpoint_indices) > _DRY_MAX_OCCURRENCES:
            endpoint_indices = endpoint_indices[-_DRY_MAX_OCCURRENCES:]

-        # 6. Calculate match lengths for potential next tokens
-        ngram_lens = torch.zeros(vocab_size, dtype=torch.int32, device=device)
-        found_early_exit_match = False
-
-        for idx_tensor in reversed(endpoint_indices):
-            idx = idx_tensor.item()
+        # 6. Iterate through found endpoints to compare n-grams
+        ngram_occurrence_count = 0
+        for end_idx in endpoint_indices:
+            # Compare n-grams backwards from endpoints
            match_len = 0
-            for unwind in range(1, max_ngram + 1):
-                current_idx = idx - unwind
-                history_idx = seq_len - 1 - unwind
-                if current_idx < 0: break
-                # Check breaks using the precomputed mask
-                if break_mask[current_idx] or break_mask[history_idx]: break
-                if token_seq_tensor[current_idx] != token_seq_tensor[history_idx]: break
-                match_len = unwind
-
-            if match_len >= allowed_length:  # Match length must meet minimum
-                next_tok_idx = idx + 1
-                if next_tok_idx < seq_len:
-                    next_tok = token_seq_tensor[next_tok_idx].item()
-                    # Use match_len as the length of the *matched* sequence
-                    new_len = match_len
-                    current_max = ngram_lens[next_tok].item()
-                    ngram_lens[next_tok] = max(current_max, new_len)
-                    if new_len >= _DRY_EARLY_EXIT_MATCH_LEN:
-                        found_early_exit_match = True
-
-            if found_early_exit_match: break
-
-        # 7. Apply penalty to logits for this row
-        penalty_mask = ngram_lens > 0
-        if penalty_mask.any():
-            match_lengths_for_penalty = ngram_lens[penalty_mask]
-            # Clamp exponent >= 0
-            exponents = (match_lengths_for_penalty.float() - allowed_length).clamp_(min=0.0)
-            scales = base ** exponents
-            logits[irow, penalty_mask] -= multiplier * scales
-        # --- End of DRY logic for row ---
+            for ngram_offset in range(max_ngram):
+                p_idx = end_idx - ngram_offset
+                q_idx = seq_len - 1 - ngram_offset
+                if p_idx < 0 or q_idx < 0:
+                    break  # Should not happen with logic checks
+
+                # Early exit inner loop if mismatch
+                if token_seq_tensor[p_idx] != token_seq_tensor[q_idx]:
+                    break
+
+                # Do not count matches across breaker tokens
+                if break_mask[p_idx - 1] or break_mask[q_idx - 1]:
+                    break
+
+                match_len = ngram_offset + 1
+
+            # Penalize the token that would continue the n-gram
+            next_token_idx = end_idx + 1
+            if 0 <= next_token_idx < seq_len:
+                next_token = token_seq_tensor[next_token_idx]
+                if not break_mask[next_token_idx - 1]:  # Check break mask before penalizing
+                    if 0 <= next_token < vocab_size:  # Ensure token is within vocab bounds
+                        penalty = (multiplier ** (ngram_occurrence_count + 1)) * (base ** match_len)
+                        logits[irow, next_token] /= penalty
+                        ngram_occurrence_count += 1
+
+            # Stop checking this endpoint if max occurrences reached
+            if ngram_occurrence_count >= _DRY_MAX_OCCURRENCES:
+                break
+
+            # Early exit outer loop if a long match is found
+            if match_len >= _DRY_EARLY_EXIT_MATCH_LEN:
+                break

    return logits
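For readers comparing the two approaches, here is a minimal standalone sketch (not part of this diff; the toy vocab, multiplier, base, and match values are illustrative assumptions, not values from the PR) contrasting the removed subtractive penalty with the new divisive, per-occurrence penalty:

```python
import torch

# Illustrative toy values (assumptions, not from the PR)
multiplier, base, allowed_length = 1.05, 1.75, 3
logits = torch.ones(8)   # toy vocab of 8 tokens, all logits 1.0
match_len = 4            # suffix matched 4 tokens at an earlier occurrence
next_token = 5           # token that would extend the repeated n-gram

# Removed rule: subtract multiplier * base^(match_len - allowed_length),
# with the exponent clamped at zero
old_penalty = multiplier * base ** max(match_len - allowed_length, 0.0)
print(logits[next_token] - old_penalty)   # logit the old code would leave

# New rule: divide by multiplier^(occurrence + 1) * base^match_len,
# compounding as more matched occurrences are found
occurrence = 0
new_penalty = (multiplier ** (occurrence + 1)) * (base ** match_len)
print(logits[next_token] / new_penalty)   # logit the new code would leave
```

One consequence of the divisive form worth noting in review: dividing by a penalty greater than one shrinks positive logits toward zero, but it also moves negative logits toward zero, which raises the probability of already-unlikely continuation tokens, whereas the removed subtractive form lowers the logit unconditionally.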