
Commit 03ee5df

feat: apply dry logic
1 parent 54a66e5 commit 03ee5df

2 files changed: +198 -2 lines changed

vllm/v1/sample/ops/penalties.py

Lines changed: 187 additions & 0 deletions
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
+from typing import List, Set, Optional, Dict, Any
+from vllm.v1.sample.metadata import SamplingMetadata
 
 from vllm.model_executor.layers.utils import apply_penalties
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -56,3 +58,188 @@ def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
         pin_memory=is_pin_memory_available(),
     )
     return output_tokens_tensor.to(device, non_blocking=True)
+
+
+# Constants for DRY
+_DRY_MAX_NGRAM = 12
+_DRY_MAX_OCCURRENCES = 8
+_DRY_EARLY_EXIT_MATCH_LEN = 8
+
+# Default DRY parameter values
+_DRY_DEFAULT_MULTIPLIER = 0.0
+_DRY_DEFAULT_BASE = 1.0
+_DRY_DEFAULT_ALLOWED_LEN = 3
+_DRY_DEFAULT_RANGE = 1500
+_DRY_DEFAULT_BREAKERS: Set[int] = set()
+
+
+def apply_dry(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    """
+    Applies the DRY (Don't Repeat Yourself) penalty to logits based on
+    parameters found in sampling_metadata.extra_data for each request.
+
+    Modifies logits in-place and returns the modified tensor.
+
+    Expected keys in extra_data[irow] (if DRY is active):
+      - 'dry_multiplier' (float)
+      - 'dry_base' (float)
+      - 'dry_allowed_len' (int)
+      - 'dry_range' (int)
+      - 'dry_breakers' (List[int] or Set[int])
+    """
+    batch_size, vocab_size = logits.shape
+    device = logits.device
+
+    # extra_data is assumed to be a list of dicts, one per request; adjust
+    # the access pattern if the SamplingMetadata structure differs.
+    if (not hasattr(sampling_metadata, 'extra_data')
+            or sampling_metadata.extra_data is None):
+        # Without extra_data there are no DRY parameters, so nothing to do.
+        return logits
+
+    # Cheap check whether any request in the batch has DRY enabled.
+    has_potential_dry = any(
+        data and data.get('dry_multiplier', _DRY_DEFAULT_MULTIPLIER) > 0
+        for data in sampling_metadata.extra_data)
+    if not has_potential_dry:
+        return logits
+
+    # --- Iterate through each request in the batch ---
+    for irow in range(batch_size):
+        # Skip rows the metadata does not cover (should not happen in the
+        # normal flow).
+        if irow >= len(sampling_metadata.extra_data):
+            continue
+        extra_data = sampling_metadata.extra_data[irow]
+        if not extra_data:
+            continue
+
+        # Get the DRY parameters for this row, falling back to defaults.
+        multiplier = float(
+            extra_data.get('dry_multiplier', _DRY_DEFAULT_MULTIPLIER))
+        if multiplier <= 0.0:
+            continue  # DRY not active for this row.
+
+        base = float(extra_data.get('dry_base', _DRY_DEFAULT_BASE))
+        allowed_length = int(
+            extra_data.get('dry_allowed_len', _DRY_DEFAULT_ALLOWED_LEN))
+        dry_range = int(extra_data.get('dry_range', _DRY_DEFAULT_RANGE))
+        breakers_input = extra_data.get('dry_breakers', _DRY_DEFAULT_BREAKERS)
+        breakers = (set(breakers_input)
+                    if breakers_input else _DRY_DEFAULT_BREAKERS)
+
+        # 1. Construct the token sequence (prompt + output) for this row.
+        # prompt_token_ids may be padded, so prompt_lens is needed to slice
+        # out the real prompt tokens.
+        prompt_len_attr = getattr(sampling_metadata, 'prompt_lens', None)
+        current_prompt_len = (prompt_len_attr[irow]
+                              if prompt_len_attr is not None
+                              and irow < len(prompt_len_attr) else None)
+
+        # Skip if prompt data is not available for this row.
+        if irow >= sampling_metadata.prompt_token_ids.shape[0]:
+            continue
+        prompt_tensor_row = sampling_metadata.prompt_token_ids[irow]
+
+        if current_prompt_len is not None:
+            current_prompt_tokens = (
+                prompt_tensor_row[:current_prompt_len].tolist())
+        else:
+            # Without prompt_lens the true prompt length is unknown, so the
+            # padded row cannot be sliced reliably. Skip DRY for this row.
+            # Consider adding logging: logger.warning("prompt_lens not available...")
+            continue
+
+        # Get the output tokens for this row; skip if not available.
+        if irow >= len(sampling_metadata.output_token_ids):
+            continue
+        current_output_tokens = sampling_metadata.output_token_ids[irow]
+        token_seq_list = current_prompt_tokens + current_output_tokens
+
+        # 2. Apply the range limit (only the last dry_range tokens count).
+        if dry_range > 0 and len(token_seq_list) > dry_range:
+            token_seq_list = token_seq_list[-dry_range:]
+
+        seq_len = len(token_seq_list)
+        if seq_len < 2:
+            continue  # Need at least 2 tokens.
+
+        last_token = token_seq_list[-1]
+        if last_token in breakers:
+            continue
+
+        # Convert to a tensor for efficient processing.
+        token_seq_tensor = torch.tensor(token_seq_list,
+                                        dtype=torch.long,
+                                        device=device)
+
+        # 3. Build the break mask on device.
+        break_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
+        if breakers:
+            breaker_tensor = torch.tensor(list(breakers),
+                                          dtype=torch.long,
+                                          device=device)
+            # Use broadcasting for efficiency.
+            break_mask = torch.any(
+                token_seq_tensor.unsqueeze(1) == breaker_tensor, dim=1)
+
+        # 4. Determine the maximum n-gram length possible from the end.
+        max_ngram = 0
+        for offset in range(1, min(seq_len, _DRY_MAX_NGRAM + 1)):
+            check_idx = -offset - 1
+            if check_idx < -seq_len:
+                break
+            if break_mask[check_idx]:
+                break
+            max_ngram = offset
+
+        if max_ngram < allowed_length:
+            continue
+
+        # 5. Find previous occurrences of last_token.
+        endpoint_indices = (token_seq_tensor == last_token).nonzero(
+            as_tuple=True)[0]
+        if len(endpoint_indices) < 2:
+            continue
+
+        endpoint_indices = endpoint_indices[endpoint_indices != seq_len - 1]
+        if len(endpoint_indices) == 0:
+            continue
+
+        if len(endpoint_indices) > _DRY_MAX_OCCURRENCES:
+            endpoint_indices = endpoint_indices[-_DRY_MAX_OCCURRENCES:]
+
+        # 6. Calculate match lengths for potential next tokens.
+        ngram_lens = torch.zeros(vocab_size, dtype=torch.int32, device=device)
+        found_early_exit_match = False
+
+        for idx_tensor in reversed(endpoint_indices):
+            idx = idx_tensor.item()
+            match_len = 0
+            for unwind in range(1, max_ngram + 1):
+                current_idx = idx - unwind
+                history_idx = seq_len - 1 - unwind
+                if current_idx < 0:
+                    break
+                # Check breaks using the precomputed mask.
+                if break_mask[current_idx] or break_mask[history_idx]:
+                    break
+                if (token_seq_tensor[current_idx]
+                        != token_seq_tensor[history_idx]):
+                    break
+                match_len = unwind
+
+            if match_len >= allowed_length:  # Match must meet the minimum length.
+                next_tok_idx = idx + 1
+                if next_tok_idx < seq_len:
+                    next_tok = token_seq_tensor[next_tok_idx].item()
+                    # Use match_len as the length of the *matched* sequence.
+                    new_len = match_len
+                    current_max = ngram_lens[next_tok].item()
+                    ngram_lens[next_tok] = max(current_max, new_len)
+                    if new_len >= _DRY_EARLY_EXIT_MATCH_LEN:
+                        found_early_exit_match = True
+
+            if found_early_exit_match:
+                break
+
+        # 7. Apply the penalty to this row's logits.
+        penalty_mask = ngram_lens > 0
+        if penalty_mask.any():
+            match_lengths_for_penalty = ngram_lens[penalty_mask]
+            # Clamp the exponent to be >= 0.
+            exponents = (match_lengths_for_penalty.float() -
+                         allowed_length).clamp_(min=0.0)
+            scales = base ** exponents
+            logits[irow, penalty_mask] -= multiplier * scales
+        # --- End of DRY logic for this row ---
+
+    return logits
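For intuition, the penalty computed in step 7 grows exponentially with the length of the repeated n-gram: a candidate next token whose longest match has length match_len loses multiplier * base ** max(0, match_len - allowed_length) from its logit. Below is a minimal standalone sketch of that scaling; the parameter and match-length values are chosen purely for illustration and are not taken from the commit.

import torch

# Hypothetical DRY parameters, picked only to show the scaling.
multiplier, base, allowed_length = 0.8, 1.75, 3

# Longest matched n-gram length found for three candidate next tokens.
match_lens = torch.tensor([3.0, 5.0, 8.0])

# Same formula as step 7: exponent clamped at zero, then an exponential scale.
exponents = (match_lens - allowed_length).clamp(min=0.0)
penalties = multiplier * base ** exponents
print(penalties)  # ~ tensor([ 0.8000,  2.4500, 13.1305])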

vllm/v1/sample/sampler.py

Lines changed: 11 additions & 2 deletions
@@ -8,7 +8,8 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words
 from vllm.v1.sample.ops.penalties import (apply_all_penalties,
-                                           apply_min_token_penalties)
+                                           apply_min_token_penalties,
+                                           apply_dry)
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 
 _SAMPLING_EPS = 1e-5
@@ -184,10 +185,13 @@ def apply_penalties(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
+        # Apply min tokens penalty first (modifies logits in-place).
         if sampling_metadata.min_tokens:
             apply_min_token_penalties(logits,
                                       sampling_metadata.output_token_ids,
                                       sampling_metadata.min_tokens)
+
+        # Apply standard penalties (frequency, presence, repetition).
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
             logits = apply_all_penalties(
@@ -196,8 +200,13 @@ def apply_penalties(
                 sampling_metadata.presence_penalties,
                 sampling_metadata.frequency_penalties,
                 sampling_metadata.repetition_penalties,
-                sampling_metadata.output_token_ids,
+                sampling_metadata.output_token_ids,  # Passed as a list of lists.
             )
+
+        # Apply the DRY penalty using the metadata object. The check for
+        # whether DRY is active is now inside apply_dry.
+        logits = apply_dry(logits, sampling_metadata)
+
         return logits
 
     def apply_min_p(
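To see how the two files fit together end to end, the sketch below calls apply_dry directly with a minimal stand-in for SamplingMetadata, assuming this commit is applied. Only the attributes the new code actually reads are provided (prompt_token_ids, prompt_lens, output_token_ids, extra_data); the stand-in object, the toy 16-token vocabulary, and the token ids are illustrative assumptions, not part of the commit.

import torch
from types import SimpleNamespace

from vllm.v1.sample.ops.penalties import apply_dry

# Minimal stand-in exposing only the attributes apply_dry reads; a real
# SamplingMetadata carries many more fields.
meta = SimpleNamespace(
    prompt_token_ids=torch.tensor([[5, 6, 7, 8, 0, 0]]),  # padded prompt row
    prompt_lens=[4],                                       # true prompt length
    output_token_ids=[[6, 7, 8, 5, 6, 7, 8, 5, 6, 7]],     # repeating pattern
    extra_data=[{
        'dry_multiplier': 0.8,
        'dry_base': 1.75,
        'dry_allowed_len': 2,
        'dry_range': 1500,
        'dry_breakers': set(),
    }],
)

logits = torch.zeros(1, 16)  # batch of 1, toy vocabulary of 16 tokens
logits = apply_dry(logits, meta)

# Token 8 would extend the repeated "... 5 6 7 8" n-gram, so its logit is
# pushed down (roughly -4.29 with these values); unrelated tokens stay at 0.
print(logits[0, 8].item(), logits[0, 3].item())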
