|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +import dataclasses |
| 3 | +from abc import ABC, abstractmethod |
| 4 | +from collections.abc import Sequence |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +import torch |
| 8 | +from torch._prims_common import DeviceLikeType |
| 9 | + |
| 10 | +from vllm import SamplingParams |
| 11 | + |
| 12 | + |
| 13 | +@dataclasses.dataclass |
| 14 | +class BatchUpdate: |
| 15 | + # The current number of requests in the batch. |
| 16 | + batch_size: int |
| 17 | + # Batch indices of any removed requests. |
| 18 | + removed: Sequence[int] = () |
| 19 | + # (from, to) batch indices of any requests |
| 20 | + # moved within the batch. |
| 21 | + moved: Sequence[tuple[int, int]] = () |
| 22 | + # (index, params, output_tok_ids) for new |
| 23 | + # requests added to the batch. |
| 24 | + added: Sequence[tuple[int, SamplingParams, list[int]]] = () |
| 25 | + |
| 26 | + |
| 27 | +class LogitsProcessor(ABC): |
| 28 | + |
| 29 | + @abstractmethod |
| 30 | + def apply(self, logits: torch.Tensor) -> torch.Tensor: |
| 31 | + raise NotImplementedError |
| 32 | + |
| 33 | + @abstractmethod |
| 34 | + def update_states( |
| 35 | + self, |
| 36 | + batch_update: Optional[BatchUpdate] = None, |
| 37 | + ) -> None: |
| 38 | + """Called when there are new output tokens, prior |
| 39 | + to each forward pass. |
| 40 | +
|
| 41 | + Args: |
| 42 | + batch_update is non-None iff there have been |
| 43 | + changes to the batch makeup. |
| 44 | + """ |
| 45 | + raise NotImplementedError |
| 46 | + |
| 47 | + |
| 48 | +###### ----- LogitsProcessor impls below here |
| 49 | + |
| 50 | + |
| 51 | +class MinPLogitsProcessor(LogitsProcessor): |
| 52 | + |
| 53 | + def __init__(self, max_num_reqs: int, pin_memory: bool, |
| 54 | + device: DeviceLikeType): |
| 55 | + self.min_p_count: int = 0 |
| 56 | + |
| 57 | + self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), |
| 58 | + dtype=torch.float32, |
| 59 | + device="cpu", |
| 60 | + pin_memory=pin_memory) |
| 61 | + self.min_p_cpu = self.min_p_cpu_tensor.numpy() |
| 62 | + # Pre-allocated device tensor |
| 63 | + self.min_p_gpu: torch.Tensor = torch.empty((max_num_reqs, ), |
| 64 | + dtype=torch.float32, |
| 65 | + device=device) |
| 66 | + # Current slice of the device tensor |
| 67 | + self.min_p: torch.Tensor = self.min_p_gpu[:0] |
| 68 | + |
| 69 | + def update_states(self, batch_update: Optional[BatchUpdate] = None): |
| 70 | + if not batch_update: |
| 71 | + return |
| 72 | + |
| 73 | + needs_update = False |
| 74 | + if self.min_p_count: |
| 75 | + # Process removed and moved requests. |
| 76 | + for index in batch_update.removed: |
| 77 | + if self.min_p_cpu[index]: |
| 78 | + self.min_p_count -= 1 |
| 79 | + needs_update = True |
| 80 | + |
| 81 | + for from_index, to_index in batch_update.moved: |
| 82 | + min_p = self.min_p_cpu[from_index] |
| 83 | + self.min_p_cpu[to_index] = min_p |
| 84 | + if min_p: |
| 85 | + needs_update = True |
| 86 | + |
| 87 | + # Process added requests. |
| 88 | + for index, sampling_params, _ in batch_update.added: |
| 89 | + min_p = sampling_params.min_p |
| 90 | + self.min_p_cpu[index] = min_p |
| 91 | + if min_p: |
| 92 | + self.min_p_count += 1 |
| 93 | + needs_update = True |
| 94 | + |
| 95 | + # Update tensors if needed. |
| 96 | + size = batch_update.batch_size |
| 97 | + if self.min_p_count and (needs_update or self.min_p.shape[0] != size): |
| 98 | + |
| 99 | + self.min_p = self.min_p_gpu[:size] |
| 100 | + self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) |
| 101 | + self.min_p.unsqueeze_(1) |
| 102 | + |
| 103 | + def apply(self, logits: torch.Tensor) -> torch.Tensor: |
| 104 | + if not self.min_p_count: |
| 105 | + return logits |
| 106 | + |
| 107 | + # Convert logits to probability distribution |
| 108 | + probability_values = torch.nn.functional.softmax(logits, dim=-1) |
| 109 | + # Calculate maximum probabilities per sequence |
| 110 | + max_probabilities = torch.amax(probability_values, |
| 111 | + dim=-1, |
| 112 | + keepdim=True) |
| 113 | + # Adjust min_p |
| 114 | + adjusted_min_p = max_probabilities.mul_(self.min_p) |
| 115 | + # Identify valid tokens using threshold comparison |
| 116 | + invalid_token_mask = probability_values < adjusted_min_p |
| 117 | + # Apply mask using boolean indexing |
| 118 | + logits[invalid_token_mask] = -float('inf') |
| 119 | + return logits |
| 120 | + |
| 121 | + |
| 122 | +class LogitBiasLogitsProcessor(LogitsProcessor): |
| 123 | + |
| 124 | + def __init__(self, pin_memory: bool, device: torch.device): |
| 125 | + self.biases: dict[int, dict[int, float]] = {} |
| 126 | + self.device = device |
| 127 | + self.pin_memory = pin_memory |
| 128 | + |
| 129 | + self.bias_tensor: torch.Tensor = torch.tensor(()) |
| 130 | + self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (torch.tensor( |
| 131 | + ()), torch.tensor(())) |
| 132 | + |
| 133 | + def update_states(self, batch_update: Optional[BatchUpdate] = None): |
| 134 | + if not batch_update: |
| 135 | + return |
| 136 | + |
| 137 | + needs_update = False |
| 138 | + if self.biases: |
| 139 | + # Process removed and moved requests. |
| 140 | + for index in batch_update.removed: |
| 141 | + if self.biases.pop(index, None): |
| 142 | + needs_update = True |
| 143 | + |
| 144 | + for from_index, to_index in batch_update.moved: |
| 145 | + if entry := self.biases.pop(from_index, None): |
| 146 | + self.biases[to_index] = entry |
| 147 | + needs_update = True |
| 148 | + |
| 149 | + # Process added requests. |
| 150 | + for index, sampling_params, _ in batch_update.added: |
| 151 | + if lb := sampling_params.logit_bias: |
| 152 | + self.biases[index] = lb |
| 153 | + needs_update = True |
| 154 | + |
| 155 | + # Update tensors if needed. |
| 156 | + if self.biases and needs_update: |
| 157 | + reqs, tok_ids, biases = [], [], [] |
| 158 | + for req, lb in self.biases.items(): |
| 159 | + reqs.extend([req] * len(lb)) |
| 160 | + tok_ids.extend(lb.keys()) |
| 161 | + biases.extend(lb.values()) |
| 162 | + |
| 163 | + self.bias_tensor = self._tensor(biases, torch.float32) |
| 164 | + self.logits_slice = (self._tensor(reqs, torch.int32), |
| 165 | + self._tensor(tok_ids, torch.int32)) |
| 166 | + |
| 167 | + def _tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: |
| 168 | + return (torch.tensor(data, |
| 169 | + device="cpu", |
| 170 | + dtype=dtype, |
| 171 | + pin_memory=self.pin_memory).to(device=self.device, |
| 172 | + non_blocking=True)) |
| 173 | + |
| 174 | + def apply(self, logits: torch.Tensor) -> torch.Tensor: |
| 175 | + if self.biases: |
| 176 | + logits[self.logits_slice] += self.bias_tensor |
| 177 | + return logits |
| 178 | + |
| 179 | + |
| 180 | +class MinTokensLogitsProcessor(LogitsProcessor): |
| 181 | + |
| 182 | + def __init__(self, pin_memory: bool, device: torch.device): |
| 183 | + # index -> (min_toks, output_token_ids, stop_token_ids) |
| 184 | + self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} |
| 185 | + self.device = device |
| 186 | + self.pin_memory = pin_memory |
| 187 | + |
| 188 | + self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (torch.tensor( |
| 189 | + ()), torch.tensor(())) |
| 190 | + |
| 191 | + def update_states(self, batch_update: Optional[BatchUpdate] = None): |
| 192 | + needs_update = False |
| 193 | + if batch_update: |
| 194 | + if self.min_toks: |
| 195 | + # Process removed and moved requests. |
| 196 | + for index in batch_update.removed: |
| 197 | + if self.min_toks.pop(index, None): |
| 198 | + needs_update = True |
| 199 | + |
| 200 | + for from_index, to_index in batch_update.moved: |
| 201 | + if entry := self.min_toks.pop(from_index, None): |
| 202 | + self.min_toks[to_index] = entry |
| 203 | + needs_update = True |
| 204 | + |
| 205 | + # Process added requests. |
| 206 | + for index, sampling_params, output_tok_ids in batch_update.added: |
| 207 | + if ((min_tokens := sampling_params.min_tokens) |
| 208 | + and len(output_tok_ids) < min_tokens): |
| 209 | + self.min_toks[index] = (min_tokens, output_tok_ids, |
| 210 | + sampling_params.all_stop_token_ids) |
| 211 | + needs_update = True |
| 212 | + |
| 213 | + if self.min_toks: |
| 214 | + # Check for any requests that have attained their min tokens. |
| 215 | + to_remove = tuple(index for index, (min_toks, out_tok_ids, |
| 216 | + _) in self.min_toks.items() |
| 217 | + if len(out_tok_ids) >= min_toks) |
| 218 | + if to_remove: |
| 219 | + needs_update = True |
| 220 | + for index in to_remove: |
| 221 | + del self.min_toks[index] |
| 222 | + |
| 223 | + # Update tensors if needed. |
| 224 | + if needs_update and self.min_toks: |
| 225 | + reqs: list[int] = [] |
| 226 | + tok_ids: list[int] = [] |
| 227 | + for req, (_, _, stop_tok_ids) in self.min_toks.items(): |
| 228 | + reqs.extend([req] * len(stop_tok_ids)) |
| 229 | + tok_ids.extend(stop_tok_ids) |
| 230 | + |
| 231 | + self.logits_slice = (self._tensor(reqs, torch.int32), |
| 232 | + self._tensor(tok_ids, torch.int32)) |
| 233 | + |
| 234 | + def _tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: |
| 235 | + return (torch.tensor(data, |
| 236 | + device="cpu", |
| 237 | + dtype=dtype, |
| 238 | + pin_memory=self.pin_memory).to(device=self.device, |
| 239 | + non_blocking=True)) |
| 240 | + |
| 241 | + def apply(self, logits: torch.Tensor) -> torch.Tensor: |
| 242 | + if self.min_toks: |
| 243 | + logits[self.logits_slice] = -float("inf") |
| 244 | + return logits |
0 commit comments