
Commit b504b73

[RFC][V1] LogitsProcessor interface
Signed-off-by: Nick Hill <[email protected]>
1 parent 54a66e5 commit b504b73

9 files changed: +354 -159 lines changed

vllm/v1/attention/backends/flash_attn.py (+5 -3)

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with FlashAttention."""
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -279,9 +280,10 @@ class FlashAttentionMetadataBuilder:
     def __init__(self, runner: "GPUModelRunner"):
         self.runner = runner
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
+    def reorder_batch(
+            self, input_batch: "InputBatch",
+            scheduler_output: "SchedulerOutput") -> Sequence[tuple[int, int]]:
+        return ()
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
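
The return-type change above replaces the old "did we modify the batch?" boolean with the sequence of (from, to) index moves that were actually performed; FlashAttention does no reordering, so it returns an empty tuple. As an illustration only of how that return value is meant to be consumed (the runner-side plumbing is not part of this hunk; `builder`, `input_batch`, `scheduler_output`, and `logits_procs` are assumed stand-ins), a caller could forward the moves to the logits processors introduced later in this commit:

    from collections.abc import Sequence

    from vllm.v1.sample.logits_processor import BatchUpdate


    def notify_logits_procs_of_reorder(builder, input_batch, scheduler_output,
                                       logits_procs) -> None:
        # Hypothetical helper: replay the (from, to) moves reported by the
        # attention metadata builder into each logits processor's state.
        moves: Sequence[tuple[int, int]] = builder.reorder_batch(
            input_batch, scheduler_output)
        if moves:  # an empty sequence means the batch layout is unchanged
            update = BatchUpdate(batch_size=input_batch.num_reqs,
                                 moved=list(moves))
            for proc in logits_procs:
                proc.update_states(update)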

vllm/v1/attention/backends/mla/common.py (+20 -11)

@@ -186,6 +186,7 @@
 
 import functools
 from abc import abstractmethod
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
 
@@ -377,8 +378,11 @@ def __init__(self,
         )
         self.page_size = self.runner.block_size
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(
+        self,
+        input_batch: "InputBatch",
+        scheduler_output: "SchedulerOutput",
+    ) -> Sequence[tuple[int, int]]:
         # We now want to reorder the batch so that the "decode" requests are at
         # the front and the "prefill" requests are at the back, using the least
         # amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests
@@ -415,20 +419,25 @@ def reorder_batch(self, input_batch: "InputBatch",
         # the above loop
         num_decodes = len(decodes)
         num_prefills = len(prefills)
-        first_prefill = 0
-        modified_batch = False
 
+        swaps = []
         for i in range(1, min(num_decodes, num_prefills) + 1):
             # If the decode is at the "back" of the batch, i, we can swap it
             # with the prefill closest to the front of the batch
-            if decodes[num_decodes - i] >= num_decodes:
-                input_batch.swap_states(prefills[first_prefill],
-                                        decodes[num_decodes - i])
-                first_prefill += 1
-                modified_batch = True
-            else:
+            if decodes[num_decodes - i] < num_decodes:
                 break
 
+            i1 = prefills[i - 1]
+            i2 = decodes[num_decodes - i]
+            input_batch.swap_states(i1, i2)
+
+            # Using "move" operation of LogitsProcessors via temporary slot
+            # currently.
+            # TODO possibly add more direct swap operation to LPs
+            swaps.append((i1, input_batch.max_num_reqs))
+            swaps.append((i2, i1))
+            swaps.append((input_batch.max_num_reqs, i2))
+
         # Save for next `build` call
         # TODO(lucas): this is a bit of a hack, we should probably have a
         # better way of doing this
@@ -437,7 +446,7 @@ def reorder_batch(self, input_batch: "InputBatch",
         self._num_decode_tokens = num_decode_tokens
         self._num_prefill_tokens = num_prefill_tokens
 
-        return modified_batch
+        return swaps
 
     def _build_decode(self, input_positions: torch.Tensor,
                       block_table: torch.Tensor, seq_lens: torch.Tensor):
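
The three tuples appended per swap encode an exchange of slots i1 and i2 as moves through a temporary slot at index `input_batch.max_num_reqs`, since the LogitsProcessor interface introduced below only understands (from, to) moves, not swaps. A minimal stand-alone sketch of how a processor that keys per-request state by batch index would replay such a sequence (the dict-based state and the concrete indices here are illustrative, not taken from this commit):

    # Per-request state keyed by batch index, replayed move by move.
    state = {3: "prefill-A", 7: "decode-B"}
    temp_slot = 16                    # stands in for input_batch.max_num_reqs
    swap_as_moves = [(3, temp_slot), (7, 3), (temp_slot, 7)]

    for from_index, to_index in swap_as_moves:
        if (entry := state.pop(from_index, None)) is not None:
            state[to_index] = entry

    assert state == {3: "decode-B", 7: "prefill-A"}  # slots 3 and 7 swapped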

vllm/v1/sample/logits_processor.py (+244, new file)

@@ -0,0 +1,244 @@
+# SPDX-License-Identifier: Apache-2.0
+import dataclasses
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Optional
+
+import torch
+from torch._prims_common import DeviceLikeType
+
+from vllm import SamplingParams
+
+
+@dataclasses.dataclass
+class BatchUpdate:
+    # The current number of requests in the batch.
+    batch_size: int
+    # Batch indices of any removed requests.
+    removed: Sequence[int] = ()
+    # (from, to) batch indices of any requests
+    # moved within the batch.
+    moved: Sequence[tuple[int, int]] = ()
+    # (index, params, output_tok_ids) for new
+    # requests added to the batch.
+    added: Sequence[tuple[int, SamplingParams, list[int]]] = ()
+
+
+class LogitsProcessor(ABC):
+
+    @abstractmethod
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    @abstractmethod
+    def update_states(
+        self,
+        batch_update: Optional[BatchUpdate] = None,
+    ) -> None:
+        """Called when there are new output tokens, prior
+        to each forward pass.
+
+        Args:
+            batch_update is non-None iff there have been
+            changes to the batch makeup.
+        """
+        raise NotImplementedError
+
+
+###### ----- LogitsProcessor impls below here
+
+
+class MinPLogitsProcessor(LogitsProcessor):
+
+    def __init__(self, max_num_reqs: int, pin_memory: bool,
+                 device: DeviceLikeType):
+        self.min_p_count: int = 0
+
+        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
+                                            dtype=torch.float32,
+                                            device="cpu",
+                                            pin_memory=pin_memory)
+        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
+        # Pre-allocated device tensor
+        self.min_p_gpu: torch.Tensor = torch.empty((max_num_reqs, ),
+                                                   dtype=torch.float32,
+                                                   device=device)
+        # Current slice of the device tensor
+        self.min_p: torch.Tensor = self.min_p_gpu[:0]
+
+    def update_states(self, batch_update: Optional[BatchUpdate] = None):
+        if not batch_update:
+            return
+
+        needs_update = False
+        if self.min_p_count:
+            # Process removed and moved requests.
+            for index in batch_update.removed:
+                if self.min_p_cpu[index]:
+                    self.min_p_count -= 1
+                    needs_update = True
+
+            for from_index, to_index in batch_update.moved:
+                min_p = self.min_p_cpu[from_index]
+                self.min_p_cpu[to_index] = min_p
+                if min_p:
+                    needs_update = True
+
+        # Process added requests.
+        for index, sampling_params, _ in batch_update.added:
+            min_p = sampling_params.min_p
+            self.min_p_cpu[index] = min_p
+            if min_p:
+                self.min_p_count += 1
+                needs_update = True
+
+        # Update tensors if needed.
+        size = batch_update.batch_size
+        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
+
+            self.min_p = self.min_p_gpu[:size]
+            self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
+            self.min_p.unsqueeze_(1)
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if not self.min_p_count:
+            return logits
+
+        # Convert logits to probability distribution
+        probability_values = torch.nn.functional.softmax(logits, dim=-1)
+        # Calculate maximum probabilities per sequence
+        max_probabilities = torch.amax(probability_values,
+                                       dim=-1,
+                                       keepdim=True)
+        # Adjust min_p
+        adjusted_min_p = max_probabilities.mul_(self.min_p)
+        # Identify valid tokens using threshold comparison
+        invalid_token_mask = probability_values < adjusted_min_p
+        # Apply mask using boolean indexing
+        logits[invalid_token_mask] = -float('inf')
+        return logits
+
+
+class LogitBiasLogitsProcessor(LogitsProcessor):
+
+    def __init__(self, pin_memory: bool, device: torch.device):
+        self.biases: dict[int, dict[int, float]] = {}
+        self.device = device
+        self.pin_memory = pin_memory
+
+        self.bias_tensor: torch.Tensor = torch.tensor(())
+        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (torch.tensor(
+            ()), torch.tensor(()))
+
+    def update_states(self, batch_update: Optional[BatchUpdate] = None):
+        if not batch_update:
+            return
+
+        needs_update = False
+        if self.biases:
+            # Process removed and moved requests.
+            for index in batch_update.removed:
+                if self.biases.pop(index, None):
+                    needs_update = True
+
+            for from_index, to_index in batch_update.moved:
+                if entry := self.biases.pop(from_index, None):
+                    self.biases[to_index] = entry
+                    needs_update = True
+
+        # Process added requests.
+        for index, sampling_params, _ in batch_update.added:
+            if lb := sampling_params.logit_bias:
+                self.biases[index] = lb
+                needs_update = True
+
+        # Update tensors if needed.
+        if self.biases and needs_update:
+            reqs, tok_ids, biases = [], [], []
+            for req, lb in self.biases.items():
+                reqs.extend([req] * len(lb))
+                tok_ids.extend(lb.keys())
+                biases.extend(lb.values())
+
+            self.bias_tensor = self._tensor(biases, torch.float32)
+            self.logits_slice = (self._tensor(reqs, torch.int32),
+                                 self._tensor(tok_ids, torch.int32))
+
+    def _tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
+        return (torch.tensor(data,
+                             device="cpu",
+                             dtype=dtype,
+                             pin_memory=self.pin_memory).to(device=self.device,
+                                                            non_blocking=True))
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if self.biases:
+            logits[self.logits_slice] += self.bias_tensor
+        return logits
+
+
+class MinTokensLogitsProcessor(LogitsProcessor):
+
+    def __init__(self, pin_memory: bool, device: torch.device):
+        # index -> (min_toks, output_token_ids, stop_token_ids)
+        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
+        self.device = device
+        self.pin_memory = pin_memory
+
+        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (torch.tensor(
+            ()), torch.tensor(()))
+
+    def update_states(self, batch_update: Optional[BatchUpdate] = None):
+        needs_update = False
+        if batch_update:
+            if self.min_toks:
+                # Process removed and moved requests.
+                for index in batch_update.removed:
+                    if self.min_toks.pop(index, None):
+                        needs_update = True
+
+                for from_index, to_index in batch_update.moved:
+                    if entry := self.min_toks.pop(from_index, None):
+                        self.min_toks[to_index] = entry
+                        needs_update = True
+
+            # Process added requests.
+            for index, sampling_params, output_tok_ids in batch_update.added:
+                if ((min_tokens := sampling_params.min_tokens)
+                        and len(output_tok_ids) < min_tokens):
+                    self.min_toks[index] = (min_tokens, output_tok_ids,
+                                            sampling_params.all_stop_token_ids)
+                    needs_update = True
+
+        if self.min_toks:
+            # Check for any requests that have attained their min tokens.
+            to_remove = tuple(index for index, (min_toks, out_tok_ids,
+                                                _) in self.min_toks.items()
+                              if len(out_tok_ids) >= min_toks)
+            if to_remove:
+                needs_update = True
+                for index in to_remove:
+                    del self.min_toks[index]
+
+        # Update tensors if needed.
+        if needs_update and self.min_toks:
+            reqs: list[int] = []
+            tok_ids: list[int] = []
+            for req, (_, _, stop_tok_ids) in self.min_toks.items():
+                reqs.extend([req] * len(stop_tok_ids))
+                tok_ids.extend(stop_tok_ids)
+
+            self.logits_slice = (self._tensor(reqs, torch.int32),
+                                 self._tensor(tok_ids, torch.int32))
+
+    def _tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
+        return (torch.tensor(data,
+                             device="cpu",
+                             dtype=dtype,
+                             pin_memory=self.pin_memory).to(device=self.device,
+                                                            non_blocking=True))
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if self.min_toks:
+            logits[self.logits_slice] = -float("inf")
+        return logits
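
A rough, self-contained sketch of how the new interface is driven: processors are told about batch changes via update_states() and then folded over the logits via apply(). The manual wiring below is for illustration only (the real integration lives in the sampler and model runner); the SamplingParams fields used here (min_tokens, stop_token_ids, logit_bias) come from vLLM's existing API:

    import torch

    from vllm import SamplingParams
    from vllm.v1.sample.logits_processor import (BatchUpdate,
                                                 LogitBiasLogitsProcessor,
                                                 MinTokensLogitsProcessor)

    device = torch.device("cpu")
    procs = [
        MinTokensLogitsProcessor(pin_memory=False, device=device),
        LogitBiasLogitsProcessor(pin_memory=False, device=device),
    ]

    # One request added at batch index 0; output_tok_ids is the live, growing
    # list of generated tokens that MinTokensLogitsProcessor watches.
    output_tok_ids: list[int] = []
    params = SamplingParams(min_tokens=4, stop_token_ids=[7],
                            logit_bias={42: -2.0})
    update = BatchUpdate(batch_size=1, added=[(0, params, output_tok_ids)])

    logits = torch.randn(1, 128)
    for proc in procs:
        proc.update_states(update)  # batch makeup changed this step
    for proc in procs:
        logits = proc.apply(logits)

    assert logits[0, 7] == -float("inf")  # stop token masked until min_tokens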

vllm/v1/sample/metadata.py (+5 -6)

@@ -5,6 +5,8 @@
 
 import torch
 
+from vllm.v1.sample.logits_processor import LogitsProcessor
+
 
 @dataclass
 class SamplingMetadata:
@@ -15,7 +17,6 @@ class SamplingMetadata:
 
     top_p: Optional[torch.Tensor]
     top_k: Optional[torch.Tensor]
-    min_p: Optional[torch.Tensor]
 
    generators: dict[int, torch.Generator]
 
@@ -30,14 +31,12 @@ class SamplingMetadata:
 
     output_token_ids: list[list[int]]
 
-    # req_index -> (min_tokens, stop_token_ids)
-    min_tokens: dict[int, tuple[int, set[int]]]
-
-    logit_bias: list[Optional[dict[int, float]]]
-
     # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
     # vocab size).
     allowed_token_ids_mask: Optional[torch.Tensor]
 
     # req_index -> bad_words_token_ids
     bad_words_token_ids: dict[int, list[list[int]]]
+
+    logits_procs: list[LogitsProcessor]
+    nongreedy_logits_procs: list[LogitsProcessor]
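
With min_p, min_tokens, and logit_bias removed from SamplingMetadata, the sampler is left to fold the two processor lists over the logits. A hedged sketch of that consumption (the greedy/non-greedy split is inferred from the new field names, `all_greedy` is an existing SamplingMetadata flag, and the commit's actual sampler change is not shown here):

    import torch

    from vllm.v1.sample.metadata import SamplingMetadata


    def apply_logits_processors(logits: torch.Tensor,
                                metadata: SamplingMetadata) -> torch.Tensor:
        # Processors that only matter when sampling (e.g. min-p) can be
        # skipped for an all-greedy batch; the rest always run.
        if not metadata.all_greedy:
            for proc in metadata.nongreedy_logits_procs:
                logits = proc.apply(logits)
        for proc in metadata.logits_procs:
            logits = proc.apply(logits)
        return logits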

vllm/v1/sample/ops/penalties.py (-16)

@@ -6,22 +6,6 @@
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 
 
-def apply_min_token_penalties(
-        logits: torch.Tensor, output_token_ids: list[list[int]],
-        min_tokens: dict[int, tuple[int, set[int]]]) -> None:
-    """
-    Applies minimum token penalty by setting the logits of the stop tokens
-    to -inf.
-    """
-    min_tokens_logits_to_penalize: list[tuple[int, int]] = []
-    for index, (min_token, stop_token_ids) in min_tokens.items():
-        if len(output_token_ids[index]) < min_token:
-            for stop_token_id in stop_token_ids:
-                min_tokens_logits_to_penalize.append((index, stop_token_id))
-    if min_tokens_logits_to_penalize:
-        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
-
-
 def apply_all_penalties(
     logits: torch.Tensor,
     prompt_token_ids: torch.Tensor,
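
The helper removed here is superseded by MinTokensLogitsProcessor above: rather than rebuilding the (request index, stop-token id) pairs from a min_tokens dict on every call, the processor maintains them incrementally in update_states(), and its apply() reduces to the same fancy-indexed assignment. A minimal sketch of that core operation (index values are illustrative only):

    import torch

    logits = torch.randn(2, 32)                      # (num_reqs, vocab_size)
    rows = torch.tensor([0, 0], dtype=torch.int32)   # request indices
    cols = torch.tensor([7, 13], dtype=torch.int32)  # stop-token ids
    logits[(rows, cols)] = -float("inf")             # suppress stop tokens until min_tokens is reached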
