
[WIP]: DRY sampling #16695


Draft PR: wants to merge 1 commit into main.
201 changes: 201 additions & 0 deletions vllm/v1/sample/ops/penalties.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from typing import List, Optional, Dict, Any
from vllm.v1.sample.metadata import SamplingMetadata

from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -56,3 +58,202 @@
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)


# Constants for DRY
_DRY_MAX_NGRAM = 12
_DRY_MAX_OCCURRENCES = 8
_DRY_EARLY_EXIT_MATCH_LEN = 8

# Default DRY parameter values
_DRY_DEFAULT_MULTIPLIER = 0.0
_DRY_DEFAULT_BASE = 1.0
_DRY_DEFAULT_ALLOWED_LEN = 3
_DRY_DEFAULT_RANGE = 1500
_DRY_DEFAULT_BREAKERS: set[int] = set()


def apply_dry(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """
    Applies the DRY (Don't Repeat Yourself) penalty to logits based on
    parameters found in sampling_metadata.extra_data for each request.

    Modifies logits in place and returns the tensor.

    Expected keys in extra_data[irow] (if DRY is active):
      - 'dry_multiplier' (float)
      - 'dry_base' (float)
      - 'dry_allowed_len' (int)
      - 'dry_range' (int)
      - 'dry_breakers' (List[int] or Set[int])
    """
    batch_size, vocab_size = logits.shape
    device = logits.device

    # extra_data is assumed to be a list of dicts, one per request.
    # Adjust the access pattern if the SamplingMetadata structure differs.
    if (not hasattr(sampling_metadata, 'extra_data')
            or sampling_metadata.extra_data is None):
        # Without extra_data there are no DRY parameters to read.
        return logits

    # Prompt tokens are required to rebuild each row's full sequence; the
    # explicit None check also narrows the Optional type for mypy.
    if sampling_metadata.prompt_token_ids is None:
        return logits

    # Cheap check whether any request in the batch has DRY enabled.
    has_potential_dry = any(
        data and data.get('dry_multiplier', _DRY_DEFAULT_MULTIPLIER) > 0
        for data in sampling_metadata.extra_data)
    if not has_potential_dry:
        return logits

Comment on lines +99 to +110

Contributor:
this check should ideally be moved to a gpu_input_batch property like penalties.

@0xymoro (Author), Apr 17, 2025:
There is a longer discussion in the #bounty-program thread in Slack, but I think Nick was saying he wanted it to be part of extra args. I'm not too familiar with the codebase itself, but DRY essentially works like the other penalties and has a few params. It would be great if it were a first-class sampler where things like dry_multiplier are passed into the engine the same way repetition_penalty is. I'm fairly confused about the passing of extra args myself, since parts of it read as if they are not fully integrated yet, but the Slack thread and its latest messages were what I was going by.

Member:
@0xymoro apologies for the confusion. As @NickLucche says, the state that only depends on the current set of requests and their parameters should go in InputBatch (you can see the state associated with other sampling params there).

Then, add logic in the add_request method to update this state based on your own request parameters, which you can retrieve from request.sampling_params.extra_args.

Logic should also be added to the remove_request(), swap_states() and condense() methods to remove/reorder requests within the preallocated state, and to _make_sampling_metadata() to update the SamplingMetadata based on the current (changed) state in the input batch.

All of this is what #13360 aims to abstract/encapsulate into the LogitProcessor interface.
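
As a rough illustration of that suggestion (not part of this PR's diff), the sketch below keeps per-request DRY parameters in InputBatch and exposes them to the sampler via the metadata. The class and field names (InputBatchSketch, dry_params, make_extra_data) are hypothetical; the real InputBatch in vllm/v1/worker/gpu_input_batch.py holds much more state and its methods take different arguments.

from typing import Any, Optional

class InputBatchSketch:
    """Hypothetical subset of InputBatch showing only the DRY state."""

    def __init__(self, max_num_reqs: int):
        # One slot per batch row, mirroring how other per-request
        # sampling state is preallocated.
        self.dry_params: list[Optional[dict[str, Any]]] = [None] * max_num_reqs

    def add_request(self, extra_args: Optional[dict], req_index: int) -> None:
        # extra_args stands in for request.sampling_params.extra_args.
        extra = extra_args or {}
        if float(extra.get("dry_multiplier", 0.0)) > 0.0:
            self.dry_params[req_index] = {
                "dry_multiplier": float(extra["dry_multiplier"]),
                "dry_base": float(extra.get("dry_base", 1.0)),
                "dry_allowed_len": int(extra.get("dry_allowed_len", 3)),
                "dry_range": int(extra.get("dry_range", 1500)),
                "dry_breakers": set(extra.get("dry_breakers", ())),
            }
        else:
            self.dry_params[req_index] = None

    def remove_request(self, req_index: int) -> None:
        self.dry_params[req_index] = None

    def make_extra_data(self, num_reqs: int) -> list[Optional[dict[str, Any]]]:
        # What _make_sampling_metadata() would copy into
        # SamplingMetadata.extra_data for the first num_reqs rows.
        return self.dry_params[:num_reqs]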


    # --- Iterate through each request in the batch ---
    for irow in range(batch_size):
        # Skip rows without a metadata entry (should not happen normally).
        if irow >= len(sampling_metadata.extra_data):
            continue
        extra_data = sampling_metadata.extra_data[irow]
        if not extra_data:
            continue

        # Get DRY parameters for this row, using defaults if missing.
        multiplier = float(
            extra_data.get('dry_multiplier', _DRY_DEFAULT_MULTIPLIER))
        if multiplier <= 0.0:
            continue  # DRY not active for this row

        base = float(extra_data.get('dry_base', _DRY_DEFAULT_BASE))
        allowed_length = int(
            extra_data.get('dry_allowed_len', _DRY_DEFAULT_ALLOWED_LEN))
        dry_range = int(extra_data.get('dry_range', _DRY_DEFAULT_RANGE))
        breakers_input = extra_data.get('dry_breakers', _DRY_DEFAULT_BREAKERS)
        breakers = (set(breakers_input)
                    if breakers_input else _DRY_DEFAULT_BREAKERS)

        # 1. Construct the token sequence for this row: prompt tokens
        #    (trimmed to the real prompt length) plus the output so far.
        prompt_len_attr = getattr(sampling_metadata, 'prompt_lens', None)
        current_prompt_len = (prompt_len_attr[irow]
                              if prompt_len_attr
                              and irow < len(prompt_len_attr) else None)

        # Ensure prompt_token_ids covers this row.
        if irow >= sampling_metadata.prompt_token_ids.shape[0]:
            continue  # Skip if prompt data isn't available for this row

        prompt_tensor_row = sampling_metadata.prompt_token_ids[irow]

        if current_prompt_len is not None:
            current_prompt_tokens = (
                prompt_tensor_row[:current_prompt_len].tolist())
        else:
            # Without prompt_lens we cannot reliably determine the prompt
            # length, so skip DRY for this row (consider logging a warning).
            continue

        # Get output tokens for this row.
        if irow >= len(sampling_metadata.output_token_ids):
            continue  # Skip if output data isn't available
        current_output_tokens = sampling_metadata.output_token_ids[irow]
        token_seq_list = current_prompt_tokens + current_output_tokens

        # 2. Apply the range limit (only the last `dry_range` tokens count).
        if dry_range > 0 and len(token_seq_list) > dry_range:
            token_seq_list = token_seq_list[-dry_range:]

        seq_len = len(token_seq_list)
        if seq_len < 2:
            continue  # Need at least 2 tokens

        last_token = token_seq_list[-1]
        if last_token in breakers:
            continue

        # Convert to a tensor for efficient processing.
        token_seq_tensor = torch.tensor(token_seq_list,
                                        dtype=torch.long,
                                        device=device)

        # 3. Build the break mask on device.
        break_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
        if breakers:
            breaker_tensor = torch.tensor(list(breakers),
                                          dtype=torch.long,
                                          device=device)
            # Use broadcasting for efficiency.
            break_mask = torch.any(
                token_seq_tensor.unsqueeze(1) == breaker_tensor, dim=1)

        # 4. Determine the longest n-gram ending at the last token that
        #    does not cross a breaker.
        max_ngram = 0
        for offset in range(1, min(seq_len, _DRY_MAX_NGRAM + 1)):
            check_idx = -offset - 1
            if check_idx < -seq_len:
                break
            if break_mask[check_idx]:
                break
            max_ngram = offset

        if max_ngram < allowed_length:
            continue

        # 5. Find previous occurrences of last_token.
        endpoint_indices = (
            token_seq_tensor == last_token).nonzero(as_tuple=True)[0]
        if len(endpoint_indices) < 2:
            continue

        endpoint_indices = endpoint_indices[endpoint_indices != seq_len - 1]
        if len(endpoint_indices) == 0:
            continue

        if len(endpoint_indices) > _DRY_MAX_OCCURRENCES:
            endpoint_indices = endpoint_indices[-_DRY_MAX_OCCURRENCES:]

        # 6. Calculate match lengths for potential next tokens.
        ngram_lens = torch.zeros(vocab_size, dtype=torch.int32, device=device)
        found_early_exit_match = False

        for idx_tensor in reversed(endpoint_indices):
            idx = idx_tensor.item()
            match_len = 0
            for unwind in range(1, max_ngram + 1):
                current_idx = idx - unwind
                history_idx = seq_len - 1 - unwind
                if current_idx < 0:
                    break
                # Check breaks using the precomputed mask.
                if break_mask[current_idx] or break_mask[history_idx]:
                    break
                if (token_seq_tensor[current_idx]
                        != token_seq_tensor[history_idx]):
                    break
                match_len = unwind

            if match_len >= allowed_length:  # Match must meet the minimum
                next_tok_idx = idx + 1
                if next_tok_idx < seq_len:
                    next_tok = token_seq_tensor[next_tok_idx].item()
                    # match_len is the length of the matched suffix.
                    new_len = match_len
                    current_max = ngram_lens[next_tok].item()
                    ngram_lens[next_tok] = max(current_max, new_len)
                    if new_len >= _DRY_EARLY_EXIT_MATCH_LEN:
                        found_early_exit_match = True

            if found_early_exit_match:
                break

        # 7. Apply the penalty to this row's logits.
        penalty_mask = ngram_lens > 0
        if penalty_mask.any():
            match_lengths_for_penalty = ngram_lens[penalty_mask]
            # Clamp the exponent to be >= 0.
            exponents = (match_lengths_for_penalty.float() -
                         allowed_length).clamp_(min=0.0)
            scales = base**exponents
            logits[irow, penalty_mask] -= multiplier * scales
        # --- End of DRY logic for row ---

    return logits
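
For intuition on step 7: the amount subtracted from a candidate token's logit grows exponentially with the length of the repeated suffix that the token would extend. A tiny standalone illustration of the same formula, with arbitrarily chosen values (not part of the diff):

# penalty = multiplier * base ** max(match_len - allowed_length, 0)
multiplier, base, allowed_length = 0.8, 1.75, 2
for match_len in (2, 3, 5, 8):
    penalty = multiplier * base ** max(match_len - allowed_length, 0)
    print(match_len, round(penalty, 2))
# 2 0.8
# 3 1.4
# 5 4.29
# 8 22.98
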
13 changes: 11 additions & 2 deletions vllm/v1/sample/sampler.py
@@ -8,7 +8,8 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import (apply_all_penalties, apply_dry,
                                          apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

_SAMPLING_EPS = 1e-5
@@ -184,10 +185,13 @@ def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        # Apply the min-tokens penalty first (modifies logits in place).
        if sampling_metadata.min_tokens:
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)

        # Apply the standard penalties (frequency, presence, repetition).
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
@@ -196,8 +200,13 @@
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,  # passed as list of lists
            )

        # Apply the DRY penalty using the metadata object.
        # The check for whether DRY is active lives inside apply_dry.
        logits = apply_dry(logits, sampling_metadata)

        return logits

    def apply_min_p(
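
End-to-end usage sketch (not part of the diff): assuming the per-request DRY parameters are forwarded from SamplingParams.extra_args into SamplingMetadata.extra_data as discussed in the review thread, a request could enable DRY roughly as below. The model name, parameter values, and breaker token id are placeholders.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(
    max_tokens=128,
    extra_args={
        "dry_multiplier": 0.8,   # > 0 enables DRY for this request
        "dry_base": 1.75,
        "dry_allowed_len": 2,
        "dry_range": 1500,
        "dry_breakers": [198],   # token ids that reset matching (illustrative)
    },
)
outputs = llm.generate(["Once upon a time"], params)
print(outputs[0].outputs[0].text)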