
Commit 6c11ecf

[Bugfix] Validate logit biases to prevent out of vocab ids crashing engine (#16529)
Signed-off-by: Ryan McConville <[email protected]>
1 parent 93e5f3c commit 6c11ecf

File tree

3 files changed: +119 −0 lines changed
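For context, the sketch below shows the request path this commit hardens, assuming a vLLM OpenAI-compatible server is already running at http://localhost:8000/v1 (the URL, API key, prompt, and the oversized token id are illustrative placeholders, not taken from this commit). With the validation in place, the bad `logit_bias` key is rejected with HTTP 400 at request time instead of reaching the sampler.

```python
# Minimal sketch (assumptions noted above): a chat completion whose logit_bias
# references a token id far beyond any real vocabulary.
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        await client.chat.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=5,
            logit_bias={"10000000": 1.0},  # out-of-vocab token id
        )
    except openai.BadRequestError as e:
        # With this fix the server answers 400 and names the invalid id.
        print(e.status_code, e)


asyncio.run(main())
```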
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0

import openai
import pytest
import pytest_asyncio

from vllm.config import ModelConfig

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


def get_vocab_size(model_name):
    config = ModelConfig(
        model=model_name,
        task="auto",
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    return config.get_vocab_size()


@pytest.fixture(scope="module")
def server():
    args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "1024",
        "--enforce-eager",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_chat_logit_bias_valid(client):
    """Test that valid logit_bias values are accepted in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    valid_token_id = vocab_size - 1

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "Testing valid logit bias"
        }],
        max_tokens=5,
        logit_bias={str(valid_token_id): 1.0},
    )

    assert completion.choices[0].message.content is not None


@pytest.mark.asyncio
async def test_chat_logit_bias_invalid(client):
    """Test that invalid logit_bias values are rejected in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    invalid_token_id = vocab_size + 1

    with pytest.raises(openai.BadRequestError) as excinfo:
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{
                "role": "user",
                "content": "Testing invalid logit bias"
            }],
            max_tokens=5,
            logit_bias={str(invalid_token_id): 1.0},
        )

    error = excinfo.value
    error_message = str(error)

    assert error.status_code == 400
    assert str(invalid_token_id) in error_message
    assert str(vocab_size) in error_message
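As a usage aside (not part of this commit), here is a sketch of how a caller would normally obtain an in-range token id for `logit_bias` rather than probing near the vocabulary boundary; it assumes the `transformers` tokenizer for the same model is available, and the biased word is an arbitrary placeholder:

```python
# Sketch: derive a valid token id for logit_bias from the model's tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

# Encode an arbitrary word and bias its first token upward.
token_id = tokenizer.encode("Hello", add_special_tokens=False)[0]
logit_bias = {str(token_id): 5.0}  # keys are token-id strings, values are biases
print(logit_bias)
```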

vllm/v1/engine/processor.py

Lines changed: 21 additions & 0 deletions
@@ -77,6 +77,7 @@ def _validate_sampling_params(
         params: SamplingParams,
     ) -> None:
         self._validate_structured_output(params)
+        self._validate_logit_bias(params)

         if params.allowed_token_ids is None:
             return
@@ -87,6 +88,26 @@ def _validate_sampling_params(
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")

+    def _validate_logit_bias(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        """Validate logit_bias token IDs are within vocabulary range."""
+        if not params.logit_bias:
+            return
+
+        vocab_size = self.model_config.get_vocab_size()
+        invalid_token_ids = []
+
+        for token_id in params.logit_bias:
+            if token_id < 0 or token_id >= vocab_size:
+                invalid_token_ids.append(token_id)
+
+        if invalid_token_ids:
+            raise ValueError(
+                f"token_id(s) {invalid_token_ids} in logit_bias contain "
+                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
+
     def _validate_supported_sampling_params(
         self,
         params: SamplingParams,
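The same check also guards the offline path, since the V1 processor validates every request's `SamplingParams`. Below is a hedged sketch of how the new error would surface when using `vllm.LLM` directly; the model choice and the oversized token id are placeholders, and it assumes the V1 engine is in use:

```python
# Sketch (assumption: V1 engine, which runs Processor._validate_sampling_params
# shown above on every request).
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enforce_eager=True)

bad_params = SamplingParams(max_tokens=5, logit_bias={10_000_000: 1.0})
try:
    llm.generate(["Hello"], bad_params)
except ValueError as e:
    # The message names the invalid token id(s) and the vocabulary size.
    print(e)
```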

vllm/v1/sample/sampler.py

Lines changed: 10 additions & 0 deletions
@@ -230,9 +230,19 @@ def apply_logits_bias(
         # TODO(houseroad): this implementation is extremely inefficient.
         # One idea is implement this as a PyTorch C++ op, and we may
         # even optimize the logit_bias layout.
+
+        # Get vocabulary size from logits
+        vocab_size = logits.shape[-1]
+
         for i, logit_bias in enumerate(sampling_metadata.logit_bias):
             if logit_bias:
                 for token_id, bias in logit_bias.items():
+                    # Check token_id bounds to ensure within vocabulary
+                    if token_id < 0 or token_id >= vocab_size:
+                        raise ValueError(
+                            f"token_id {token_id} in logit_bias contains "
+                            f"out-of-vocab token id. Vocabulary size: "
+                            f"{vocab_size}")
                     logits[i, token_id] += bias
         return logits
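For illustration only, the sketch below reproduces the failure mode this sampler-side check defends against: an out-of-range id in `logit_bias` makes the in-place add raise an IndexError deep inside the sampling loop, well after the request was accepted, which is how the engine could previously be brought down. The tensor shape and bias values are arbitrary:

```python
# Sketch: why an unchecked logit_bias key is dangerous at sampling time.
import torch

vocab_size = 8
logits = torch.zeros(2, vocab_size)  # [num_requests, vocab_size]
logit_bias = {3: 2.5, 100: 1.0}      # 100 is out of range for this vocab

for token_id, bias in logit_bias.items():
    try:
        logits[0, token_id] += bias  # token_id=100 raises IndexError
    except IndexError as e:
        print(f"would have crashed the sampling step: {e}")
```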
