
[RFC][V1] LogitsProcessor interface #13360

Draft · wants to merge 1 commit into main
8 changes: 5 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

@@ -279,9 +280,10 @@ class FlashAttentionMetadataBuilder:
def __init__(self, runner: "GPUModelRunner"):
self.runner = runner

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False
def reorder_batch(
self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> Sequence[tuple[int, int]]:
return ()

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
31 changes: 20 additions & 11 deletions vllm/v1/attention/backends/mla/common.py
@@ -186,6 +186,7 @@

import functools
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar

@@ -377,8 +378,11 @@ def __init__(self,
)
self.page_size = self.runner.block_size

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
def reorder_batch(
self,
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
) -> Sequence[tuple[int, int]]:
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the least
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests
@@ -415,20 +419,25 @@ def reorder_batch(self, input_batch: "InputBatch",
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
first_prefill = 0
modified_batch = False

swaps = []
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
first_prefill += 1
modified_batch = True
else:
if decodes[num_decodes - i] < num_decodes:
break

i1 = prefills[i - 1]
i2 = decodes[num_decodes - i]
input_batch.swap_states(i1, i2)

# Using "move" operation of LogitsProcessors via temporary slot
# currently.
# TODO possibly add more direct swap operation to LPs
swaps.append((i1, input_batch.max_num_reqs))
swaps.append((i2, i1))
swaps.append((input_batch.max_num_reqs, i2))
@LucasWilkinson (Collaborator) commented on Apr 28, 2025:

We could maybe consider having input_batch track the swaps (i.e. calls to input_batch.swap_states) in an internal data structure that gets reset on calls to refresh_sampling_metadata. That way, if we do the "TODO: possibly add a more direct swap operation to LPs", the attention metadata implementations don't all need to be updated (FlashInfer does swapping too).
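A rough Python sketch of that suggestion (hypothetical shape; the real InputBatch holds far more per-request state than shown here):

```python
class InputBatch:
    """Sketch: only the move-tracking part of the reviewer's suggestion."""

    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs
        # (from_index, to_index) moves accumulated since the last refresh.
        self._pending_moves: list[tuple[int, int]] = []

    def swap_states(self, i1: int, i2: int) -> None:
        # ... swap per-request state between slots i1 and i2 ...
        # Record the swap as three moves through a temporary slot so that
        # LogitsProcessors (which only understand moves) can replay it.
        self._pending_moves.extend([(i1, self.max_num_reqs), (i2, i1),
                                    (self.max_num_reqs, i2)])

    def refresh_sampling_metadata(self) -> list[tuple[int, int]]:
        # Hand the accumulated moves to BatchUpdate.moved, then reset.
        moves, self._pending_moves = self._pending_moves, []
        return moves
```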


# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
@@ -437,7 +446,7 @@ def reorder_batch(self, input_batch: "InputBatch",
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens

return modified_batch
return swaps

def _build_decode(self, input_positions: torch.Tensor,
block_table: torch.Tensor, seq_lens: torch.Tensor):
244 changes: 244 additions & 0 deletions vllm/v1/sample/logits_processor.py
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional

import torch
from torch._prims_common import DeviceLikeType

from vllm import SamplingParams


@dataclasses.dataclass
class BatchUpdate:
# The current number of requests in the batch.
batch_size: int
# Batch indices of any removed requests.
removed: Sequence[int] = ()
# (from, to) batch indices of any requests
# moved within the batch.
moved: Sequence[tuple[int, int]] = ()
# (index, params, output_tok_ids) for new
# requests added to the batch.
added: Sequence[tuple[int, SamplingParams, list[int]]] = ()


class LogitsProcessor(ABC):

@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

@abstractmethod
def update_states(
self,
batch_update: Optional[BatchUpdate] = None,
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.

Args:
batch_update: non-None iff there have been
changes to the batch makeup.
"""
raise NotImplementedError


###### ----- LogitsProcessor impls below here


class MinPLogitsProcessor(LogitsProcessor):

def __init__(self, max_num_reqs: int, pin_memory: bool,
device: DeviceLikeType):
self.min_p_count: int = 0

self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
Collaborator commented:

Why have min_p_cpu on top of min_p_cpu_tensor?

Member (Author) replied:

This mirrors the existing implementation. I think the reason is that it's much cheaper to manipulate individual elements in numpy arrays than in tensors, so it's better to do those updates on a numpy view. But we still need the tensor in order to transfer to the GPU.
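A standalone sketch of that pattern (not this PR's code): a pinned CPU buffer is mutated cheaply through a numpy view that shares its memory, then copied to the device with a non-blocking transfer.

```python
import torch

max_num_reqs = 8
use_cuda = torch.cuda.is_available()

# Pinned CPU staging buffer plus a numpy view that shares its memory.
cpu_tensor = torch.zeros((max_num_reqs, ), dtype=torch.float32,
                         device="cpu", pin_memory=use_cuda)
cpu_view = cpu_tensor.numpy()

# Pre-allocated destination tensor on the target device.
device_tensor = torch.empty((max_num_reqs, ), dtype=torch.float32,
                            device="cuda" if use_cuda else "cpu")

# Cheap per-element updates go through the numpy view...
cpu_view[3] = 0.05

# ...then the whole pinned buffer is copied over asynchronously.
device_tensor.copy_(cpu_tensor, non_blocking=True)
```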

# Pre-allocated device tensor
self.min_p_gpu: torch.Tensor = torch.empty((max_num_reqs, ),
Collaborator commented:

Can we call it min_p_device? We may have use cases other than GPU, right?

dtype=torch.float32,
device=device)
Comment on lines +63 to +65

Collaborator commented:

Can we preallocate for the others as well?

Member (Author) replied:

@WoosukKwon this could potentially be done, but it's not quite as simple because the maximum size could be much larger (e.g. max_batch_size * max_logit_bias_tokens * 3 for logit bias). These inputs are sparse, so in practice the size would be much smaller, or more likely the feature wouldn't be used at all.

Intuitively, since we are minimizing how often these get updated, and still doing the CPU transfer asynchronously from pinned memory, I would guess we aren't losing much by not preallocating. But I can benchmark and see. If we do it, we may want to allocate only up to the high-water mark, so that if the feature isn't used there is no allocation at all.

Collaborator replied:

Got it, sounds good. As this is a niche feature, I think we shouldn't spend too much time on it.
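A minimal sketch of the high-water-mark idea (hypothetical helper, not part of this PR): the buffer only grows to the largest size ever requested, so an unused feature never allocates.

```python
from typing import Optional

import torch


class HighWaterMarkBuffer:
    """Lazily grown device buffer; never allocates if never used."""

    def __init__(self, dtype: torch.dtype, device: torch.device):
        self.dtype = dtype
        self.device = device
        self._buf: Optional[torch.Tensor] = None

    def get(self, size: int) -> torch.Tensor:
        if self._buf is None or self._buf.shape[0] < size:
            # Grow to the new high-water mark.
            self._buf = torch.empty((size, ), dtype=self.dtype,
                                    device=self.device)
        return self._buf[:size]
```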

# Current slice of the device tensor
self.min_p: torch.Tensor = self.min_p_gpu[:0]

def update_states(self, batch_update: Optional[BatchUpdate] = None):
if not batch_update:
return

needs_update = False
if self.min_p_count:
# Process removed and moved requests.
for index in batch_update.removed:
if self.min_p_cpu[index]:
self.min_p_count -= 1
needs_update = True

for from_index, to_index in batch_update.moved:
min_p = self.min_p_cpu[from_index]
self.min_p_cpu[to_index] = min_p
if min_p:
needs_update = True

# Process added requests.
for index, sampling_params, _ in batch_update.added:
min_p = sampling_params.min_p
self.min_p_cpu[index] = min_p
if min_p:
self.min_p_count += 1
needs_update = True

# Update tensors if needed.
size = batch_update.batch_size
if self.min_p_count and (needs_update or self.min_p.shape[0] != size):

self.min_p = self.min_p_gpu[:size]
self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
self.min_p.unsqueeze_(1)

def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.min_p_count:
return logits

# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
@DreamGenX commented on Mar 26, 2025:

I believe you can do min-p without softmax, which might be faster: set logits[i] to -inf when logits[i] < logits_max + ln(min_p), where logits_max is the largest logit. You can precompute log(min_p) in update_states.
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values,
dim=-1,
keepdim=True)
# Adjust min_p
adjusted_min_p = max_probabilities.mul_(self.min_p)
# Identify valid tokens using threshold comparison
invalid_token_mask = probability_values < adjusted_min_p
# Apply mask using boolean indexing
logits[invalid_token_mask] = -float('inf')
return logits


class LogitBiasLogitsProcessor(LogitsProcessor):

def __init__(self, pin_memory: bool, device: torch.device):
self.biases: dict[int, dict[int, float]] = {}
self.device = device
self.pin_memory = pin_memory

self.bias_tensor: torch.Tensor = torch.tensor(())
self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (torch.tensor(
()), torch.tensor(()))

def update_states(self, batch_update: Optional[BatchUpdate] = None):
if not batch_update:
return

needs_update = False
if self.biases:
# Process removed and moved requests.
for index in batch_update.removed:
if self.biases.pop(index, None):
needs_update = True

for from_index, to_index in batch_update.moved:
if entry := self.biases.pop(from_index, None):
self.biases[to_index] = entry
needs_update = True

# Process added requests.
for index, sampling_params, _ in batch_update.added:
if lb := sampling_params.logit_bias:
self.biases[index] = lb
needs_update = True

# Update tensors if needed.
if self.biases and needs_update:
reqs, tok_ids, biases = [], [], []
for req, lb in self.biases.items():
reqs.extend([req] * len(lb))
tok_ids.extend(lb.keys())
biases.extend(lb.values())

self.bias_tensor = self._tensor(biases, torch.float32)
self.logits_slice = (self._tensor(reqs, torch.int32),
self._tensor(tok_ids, torch.int32))

def _tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
return (torch.tensor(data,
device="cpu",
dtype=dtype,
pin_memory=self.pin_memory).to(device=self.device,
non_blocking=True))

def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.biases:
logits[self.logits_slice] += self.bias_tensor
return logits


class MinTokensLogitsProcessor(LogitsProcessor):

def __init__(self, pin_memory: bool, device: torch.device):
# index -> (min_toks, output_token_ids, stop_token_ids)
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
self.device = device
self.pin_memory = pin_memory

self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (torch.tensor(
()), torch.tensor(()))

def update_states(self, batch_update: Optional[BatchUpdate] = None):
needs_update = False
if batch_update:
if self.min_toks:
# Process removed and moved requests.
for index in batch_update.removed:
if self.min_toks.pop(index, None):
needs_update = True

for from_index, to_index in batch_update.moved:
if entry := self.min_toks.pop(from_index, None):
self.min_toks[to_index] = entry
needs_update = True

# Process added requests.
for index, sampling_params, output_tok_ids in batch_update.added:
if ((min_tokens := sampling_params.min_tokens)
and len(output_tok_ids) < min_tokens):
self.min_toks[index] = (min_tokens, output_tok_ids,
sampling_params.all_stop_token_ids)
needs_update = True

if self.min_toks:
# Check for any requests that have attained their min tokens.
to_remove = tuple(index for index, (min_toks, out_tok_ids,
_) in self.min_toks.items()
if len(out_tok_ids) >= min_toks)
if to_remove:
needs_update = True
for index in to_remove:
del self.min_toks[index]

# Update tensors if needed.
if needs_update and self.min_toks:
reqs: list[int] = []
tok_ids: list[int] = []
for req, (_, _, stop_tok_ids) in self.min_toks.items():
reqs.extend([req] * len(stop_tok_ids))
tok_ids.extend(stop_tok_ids)

self.logits_slice = (self._tensor(reqs, torch.int32),
self._tensor(tok_ids, torch.int32))

def _tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
return (torch.tensor(data,
device="cpu",
dtype=dtype,
pin_memory=self.pin_memory).to(device=self.device,
non_blocking=True))

def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.min_toks:
logits[self.logits_slice] = -float("inf")
return logits
11 changes: 5 additions & 6 deletions vllm/v1/sample/metadata.py
@@ -5,6 +5,8 @@

import torch

from vllm.v1.sample.logits_processor import LogitsProcessor


@dataclass
class SamplingMetadata:
@@ -15,7 +17,6 @@ class SamplingMetadata:

top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor]

generators: dict[int, torch.Generator]

@@ -30,14 +31,12 @@ class SamplingMetadata:

output_token_ids: list[list[int]]

# req_index -> (min_tokens, stop_token_ids)
min_tokens: dict[int, tuple[int, set[int]]]

logit_bias: list[Optional[dict[int, float]]]

# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]

# req_index -> bad_words_token_ids
bad_words_token_ids: dict[int, list[list[int]]]

logits_procs: list[LogitsProcessor]
nongreedy_logits_procs: list[LogitsProcessor]
16 changes: 0 additions & 16 deletions vllm/v1/sample/ops/penalties.py
@@ -6,22 +6,6 @@
from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def apply_min_token_penalties(
logits: torch.Tensor, output_token_ids: list[list[int]],
min_tokens: dict[int, tuple[int, set[int]]]) -> None:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize: list[tuple[int, int]] = []
for index, (min_token, stop_token_ids) in min_tokens.items():
if len(output_token_ids[index]) < min_token:
for stop_token_id in stop_token_ids:
min_tokens_logits_to_penalize.append((index, stop_token_id))
if min_tokens_logits_to_penalize:
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def apply_all_penalties(
logits: torch.Tensor,
prompt_token_ids: torch.Tensor,