# SPDX-License-Identifier: Apache-2.0

import torch
+from typing import Set
+from vllm.v1.sample.metadata import SamplingMetadata

from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -56,3 +58,188 @@ def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
+
+
+# Constants for DRY
+_DRY_MAX_NGRAM = 12
+_DRY_MAX_OCCURRENCES = 8
+_DRY_EARLY_EXIT_MATCH_LEN = 8
+
+# Default DRY parameter values
+_DRY_DEFAULT_MULTIPLIER = 0.0
+_DRY_DEFAULT_BASE = 1.0
+_DRY_DEFAULT_ALLOWED_LEN = 3
+_DRY_DEFAULT_RANGE = 1500
+_DRY_DEFAULT_BREAKERS: Set[int] = set()
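+
+# A token that would extend a repeated suffix of length match_len receives
+#   logits[tok] -= dry_multiplier * dry_base ** (match_len - dry_allowed_len)
+# with the exponent clamped at zero (see step 7 in apply_dry below).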
+
+
+def apply_dry(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    """
+    Apply the DRY (Don't Repeat Yourself) penalty to logits based on the
+    parameters found in sampling_metadata.extra_data for each request.
+
+    Modifies logits in place and returns the same tensor.
+
+    Expected keys in extra_data[irow] (if DRY is active):
+    - 'dry_multiplier' (float)
+    - 'dry_base' (float)
+    - 'dry_allowed_len' (int)
+    - 'dry_range' (int)
+    - 'dry_breakers' (List[int] or Set[int])
+    """
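+    # Illustrative extra_data entry that enables DRY for a request
+    # (example values, not vLLM defaults):
+    #   {'dry_multiplier': 0.8, 'dry_base': 1.75, 'dry_allowed_len': 2,
+    #    'dry_range': 1500, 'dry_breakers': [13, 198]}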
+    batch_size, vocab_size = logits.shape
+    device = logits.device
+
+    # extra_data is assumed to be a list of per-request dicts; adjust the
+    # access pattern here if the SamplingMetadata structure differs.
+    if (not hasattr(sampling_metadata, 'extra_data')
+            or sampling_metadata.extra_data is None):
+        # Without extra_data there is nothing to read DRY parameters from.
+        return logits
+
+    # Cheap pre-check: skip the whole batch if no request enables DRY.
+    has_potential_dry = any(
+        data and data.get('dry_multiplier', _DRY_DEFAULT_MULTIPLIER) > 0
+        for data in sampling_metadata.extra_data
+    )
+    if not has_potential_dry:
+        return logits
+
+    # --- Iterate through each request in the batch ---
+    for irow in range(batch_size):
+        # Ensure extra_data covers this row index.
+        if irow >= len(sampling_metadata.extra_data):
+            # Metadata doesn't cover this row (shouldn't happen); skip.
+            continue
+        extra_data = sampling_metadata.extra_data[irow]
+        if not extra_data:
+            continue
+
+        # Get DRY parameters for this row, using defaults if missing.
+        multiplier = float(
+            extra_data.get('dry_multiplier', _DRY_DEFAULT_MULTIPLIER))
+        if multiplier <= 0.0:
+            continue  # DRY not active for this row
+
+        base = float(extra_data.get('dry_base', _DRY_DEFAULT_BASE))
+        allowed_length = int(
+            extra_data.get('dry_allowed_len', _DRY_DEFAULT_ALLOWED_LEN))
+        dry_range = int(extra_data.get('dry_range', _DRY_DEFAULT_RANGE))
+        breakers_input = extra_data.get('dry_breakers', _DRY_DEFAULT_BREAKERS)
+        breakers = (set(breakers_input)
+                    if breakers_input else _DRY_DEFAULT_BREAKERS)
+
+        # 1. Construct the token sequence (prompt + output) for this row.
+        # prompt_token_ids may be padded, so prompt_lens is needed to know
+        # how many prompt tokens are real.
+        prompt_len_attr = getattr(sampling_metadata, 'prompt_lens', None)
+        current_prompt_len = None
+        if prompt_len_attr is not None and irow < len(prompt_len_attr):
+            current_prompt_len = int(prompt_len_attr[irow])
+
+        # Ensure prompt_token_ids is present and covers this row.
+        if (sampling_metadata.prompt_token_ids is None
+                or irow >= sampling_metadata.prompt_token_ids.shape[0]):
+            continue  # Skip if prompt data isn't available for this row
+        prompt_tensor_row = sampling_metadata.prompt_token_ids[irow]
+
+        if current_prompt_len is not None:
+            current_prompt_tokens = prompt_tensor_row[:current_prompt_len].tolist()
+        else:
+            # Without prompt_lens the true prompt length is unknown, so skip
+            # DRY for this row rather than penalizing padding tokens.
+            continue
+
+        # Get output tokens for this row, if available.
+        if irow >= len(sampling_metadata.output_token_ids):
+            continue  # Skip if output data isn't available
+        current_output_tokens = sampling_metadata.output_token_ids[irow]
+        token_seq_list = current_prompt_tokens + current_output_tokens
+
+        # 2. Apply range limit
+        if dry_range > 0 and len(token_seq_list) > dry_range:
+            token_seq_list = token_seq_list[-dry_range:]
+
+        seq_len = len(token_seq_list)
+        if seq_len < 2:
+            continue  # Need at least 2 tokens
+
+        last_token = token_seq_list[-1]
+        if last_token in breakers:
+            continue
+
+        # Convert to tensor for efficient processing
+        token_seq_tensor = torch.tensor(
+            token_seq_list, dtype=torch.long, device=device)
+
+        # 3. Build break mask on device
+        break_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
+        if breakers:
+            breaker_tensor = torch.tensor(
+                list(breakers), dtype=torch.long, device=device)
+            # Broadcast (seq_len, 1) against (num_breakers,) so that
+            # break_mask[i] is True when token i is a breaker.
+            break_mask = torch.any(
+                token_seq_tensor.unsqueeze(1) == breaker_tensor, dim=1)
+
+        # 4. Determine max n-gram length possible from the end
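+        # Walk backwards from the token before last_token, stopping at the
+        # first breaker; the usable suffix is capped at _DRY_MAX_NGRAM tokens.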
+        max_ngram = 0
+        for offset in range(1, min(seq_len, _DRY_MAX_NGRAM + 1)):
+            check_idx = -offset - 1
+            if check_idx < -seq_len:
+                break
+            if break_mask[check_idx]:
+                break
+            max_ngram = offset
+
+        if max_ngram < allowed_length:
+            continue
+
+        # 5. Find previous occurrences of last_token
+        endpoint_indices = (token_seq_tensor == last_token).nonzero(
+            as_tuple=True)[0]
+        if len(endpoint_indices) < 2:
+            continue
+
+        # Drop the occurrence at the very end (the current last token itself).
+        endpoint_indices = endpoint_indices[endpoint_indices != seq_len - 1]
+        if len(endpoint_indices) == 0:
+            continue
+
+        if len(endpoint_indices) > _DRY_MAX_OCCURRENCES:
+            endpoint_indices = endpoint_indices[-_DRY_MAX_OCCURRENCES:]
+
+        # 6. Calculate match lengths for potential next tokens
+        ngram_lens = torch.zeros(vocab_size, dtype=torch.int32, device=device)
+        found_early_exit_match = False
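+
+        # For each earlier occurrence of last_token (scanned newest first),
+        # measure how far the tokens before it match the current suffix; the
+        # token right after that occurrence is the candidate DRY penalizes.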
+        for idx_tensor in reversed(endpoint_indices):
+            idx = idx_tensor.item()
+            match_len = 0
+            for unwind in range(1, max_ngram + 1):
+                current_idx = idx - unwind
+                history_idx = seq_len - 1 - unwind
+                if current_idx < 0:
+                    break
+                # Check breaks using the precomputed mask
+                if break_mask[current_idx] or break_mask[history_idx]:
+                    break
+                if token_seq_tensor[current_idx] != token_seq_tensor[history_idx]:
+                    break
+                match_len = unwind
+
+            if match_len >= allowed_length:  # Match length must meet minimum
+                next_tok_idx = idx + 1
+                if next_tok_idx < seq_len:
+                    next_tok = token_seq_tensor[next_tok_idx].item()
+                    # Use match_len as the length of the *matched* sequence
+                    new_len = match_len
+                    current_max = ngram_lens[next_tok].item()
+                    ngram_lens[next_tok] = max(current_max, new_len)
+                    if new_len >= _DRY_EARLY_EXIT_MATCH_LEN:
+                        found_early_exit_match = True
+
+            if found_early_exit_match:
+                break
+
+        # 7. Apply penalty to logits for this row
+        penalty_mask = ngram_lens > 0
+        if penalty_mask.any():
+            match_lengths_for_penalty = ngram_lens[penalty_mask]
+            # Clamp exponent >= 0
+            exponents = (match_lengths_for_penalty.float() -
+                         allowed_length).clamp_(min=0.0)
+            scales = base ** exponents
+            logits[irow, penalty_mask] -= multiplier * scales
+        # --- End of DRY logic for row ---
+
+    return logits