Commit 58fd57f

[Bugfix] Fix score api for missing max_model_len validation (#12119)
Signed-off-by: Wallas Santos <[email protected]>
1 parent 87a0c07 commit 58fd57f
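
This commit makes the OpenAI-compatible score endpoint validate prompt length against max_model_len and reject a truncate_prompt_tokens value larger than the context window, so bad requests are answered with HTTP 400. A minimal reproduction sketch (not part of the commit), assuming a vLLM OpenAI-compatible server is running locally with a cross-encoder model and a small context window; the base URL, model name, and serve command below are illustrative assumptions:

# Reproduction sketch, assuming a server started along the lines of:
#   vllm serve BAAI/bge-reranker-v2-m3 --enforce-eager --max-model-len 100
import requests

BASE_URL = "http://localhost:8000"   # assumed local server address
MODEL = "BAAI/bge-reranker-v2-m3"    # assumed cross-encoder model

payload = {
    "model": MODEL,
    # Repeating the question makes the pair exceed the 100-token window.
    "text_1": "What is the capital of France?" * 20,
    "text_2": "The capital of France is Paris.",
}

# Over-long input: with this fix the server answers 400 with a clear message.
resp = requests.post(f"{BASE_URL}/score", json=payload)
print(resp.status_code)  # expected: 400
print("Please reduce the length of the input." in resp.text)  # expected: True

# A truncation size larger than max_model_len is also rejected up front.
payload["truncate_prompt_tokens"] = 101
resp = requests.post(f"{BASE_URL}/score", json=payload)
print(resp.status_code)  # expected: 400
print("Please, select a smaller truncation size." in resp.text)  # expected: True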

3 files changed: 80 additions & 33 deletions


tests/entrypoints/openai/test_score.py

Lines changed: 39 additions & 6 deletions
@@ -12,6 +12,9 @@
 def server():
     args = [
         "--enforce-eager",
+        # Will be used on tests to compare prompt input length
+        "--max-model-len",
+        "100"
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -20,8 +23,7 @@ def server():
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
-                                      model_name: str):
+def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
     text_1 = "What is the capital of France?"
     text_2 = [
         "The capital of Brazil is Brasilia.", "The capital of France is Paris."
@@ -45,8 +47,7 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
-                                       model_name: str):
+def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
     text_1 = [
         "What is the capital of the United States?",
         "What is the capital of France?"
@@ -73,8 +74,7 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
-                                     model_name: str):
+def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
     text_1 = "What is the capital of France?"
     text_2 = "The capital of France is Paris."
 
@@ -91,3 +91,36 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
     assert score.data is not None
     assert len(score.data) == 1
     assert score.data[0].score >= 0.9
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
+
+    text_1 = "What is the capital of France?" * 20
+    text_2 = [
+        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
+    ]
+
+    score_response = requests.post(server.url_for("score"),
+                                   json={
+                                       "model": model_name,
+                                       "text_1": text_1,
+                                       "text_2": text_2,
+                                   })
+    assert score_response.status_code == 400
+    # Assert just a small fragments of the response
+    assert "Please reduce the length of the input." in \
+        score_response.text
+
+    # Test truncation
+    score_response = requests.post(server.url_for("score"),
+                                   json={
+                                       "model": model_name,
+                                       "text_1": text_1,
+                                       "text_2": text_2,
+                                       "truncate_prompt_tokens": 101
+                                   })
+    assert score_response.status_code == 400
+    assert "Please, select a smaller truncation size." in \
+        score_response.text

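The new test_score_max_model_len exercises both rejection paths added by the fix. For contrast, a hypothetical companion check (not part of this commit, reusing the same module's imports and fixture) could assert that a truncation size within the fixture's 100-token window is still accepted:

# Hypothetical companion test, not included in the commit: truncating to the
# configured --max-model-len of 100 should succeed rather than return 400.
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_score_truncation_within_limit(server: RemoteOpenAIServer,
                                       model_name: str):
    text_1 = "What is the capital of France?" * 20
    text_2 = "The capital of France is Paris."

    score_response = requests.post(server.url_for("score"),
                                   json={
                                       "model": model_name,
                                       "text_1": text_1,
                                       "text_2": text_2,
                                       "truncate_prompt_tokens": 100
                                   })
    assert score_response.status_code == 200
    assert len(score_response.json()["data"]) == 1
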
vllm/entrypoints/openai/serving_engine.py

Lines changed: 9 additions & 5 deletions
@@ -203,15 +203,19 @@ def _validate_input(
     ) -> TextTokensPrompt:
         token_num = len(input_ids)
 
-        # Note: EmbeddingRequest doesn't have max_tokens
-        if isinstance(request,
-                      (EmbeddingChatRequest, EmbeddingCompletionRequest)):
+        # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
+        if isinstance(
+                request,
+            (EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest)):
+
+            operation = "score" if isinstance(request, ScoreRequest) \
+                else "embedding generation"
             if token_num > self.max_model_len:
                 raise ValueError(
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "
-                    f"{token_num} tokens in the input for embedding "
-                    f"generation. Please reduce the length of the input.")
+                    f"{token_num} tokens in the input for {operation}. "
+                    f"Please reduce the length of the input.")
             return TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)

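The change above routes ScoreRequest through the same context-length check that embedding requests already used, with only the operation name in the error message differing. A standalone sketch of that rule, using illustrative names rather than vLLM's internals:

# Illustrative sketch of the validation rule, not vLLM's actual code.
# Score and embedding requests carry no max_tokens, so the only bound to
# enforce is the model's context window (max_model_len).
def validate_prompt_length(token_num: int, max_model_len: int,
                           operation: str) -> None:
    if token_num > max_model_len:
        raise ValueError(
            f"This model's maximum context length is {max_model_len} "
            f"tokens. However, you requested {token_num} tokens in the "
            f"input for {operation}. Please reduce the length of the input.")


validate_prompt_length(99, 100, "score")   # within the window: no error
try:
    validate_prompt_length(120, 100, "score")  # too long: raises ValueError
except ValueError as exc:
    print(exc)  # the serving layer converts this into an HTTP 400 response
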
vllm/entrypoints/openai/serving_score.py

Lines changed: 32 additions & 22 deletions
@@ -101,35 +101,45 @@ async def create_score(
             if not self.model_config.is_cross_encoder:
                 raise ValueError("Model is not cross encoder.")
 
+            if truncate_prompt_tokens is not None and \
+                    truncate_prompt_tokens > self.max_model_len:
+                raise ValueError(
+                    f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
+                    f"is greater than max_model_len ({self.max_model_len})."
+                    f" Please, select a smaller truncation size.")
+
+            input_pairs = make_pairs(request.text_1, request.text_2)
+            for q, t in input_pairs:
+                request_prompt = f"{q}{tokenizer.sep_token}{t}"
+
+                tokenization_kwargs: Dict[str, Any] = {}
+                if truncate_prompt_tokens is not None:
+                    tokenization_kwargs["truncation"] = True
+                    tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+                tokenize_async = make_async(tokenizer.__call__,
+                                            executor=self._tokenizer_executor)
+                prompt_inputs = await tokenize_async(text=q,
+                                                     text_pair=t,
+                                                     **tokenization_kwargs)
+
+                input_ids = prompt_inputs["input_ids"]
+                text_token_prompt = \
+                    self._validate_input(request, input_ids, request_prompt)
+                engine_prompt = TokensPrompt(
+                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
+                    token_type_ids=prompt_inputs.get("token_type_ids"))
+
+                request_prompts.append(request_prompt)
+                engine_prompts.append(engine_prompt)
+
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
 
-        input_pairs = make_pairs(request.text_1, request.text_2)
-
-        for q, t in input_pairs:
-            request_prompt = f"{q}{tokenizer.sep_token}{t}"
-
-            tokenization_kwargs: Dict[str, Any] = {}
-            if truncate_prompt_tokens is not None:
-                tokenization_kwargs["truncation"] = True
-                tokenization_kwargs["max_length"] = truncate_prompt_tokens
-
-            tokenize_async = make_async(tokenizer.__call__,
-                                        executor=self._tokenizer_executor)
-            prompt_inputs = await tokenize_async(text=q,
-                                                 text_pair=t,
-                                                 **tokenization_kwargs)
-            engine_prompt = TokensPrompt(
-                prompt_token_ids=prompt_inputs["input_ids"],
-                token_type_ids=prompt_inputs.get("token_type_ids"))
-
-            request_prompts.append(request_prompt)
-            engine_prompts.append(engine_prompt)
-
         try:
             pooling_params = request.to_pooling_params()

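Besides adding the truncate_prompt_tokens bound, this hunk moves pair construction, tokenization, and the _validate_input call inside the try block, so a ValueError raised during preprocessing is caught and converted by create_error_response into the 400 responses the new test expects. The pairing behaviour the loop relies on is sketched roughly below; this is an illustration consistent with the tests in this commit, not a copy of vLLM's make_pairs implementation:

# Rough illustration of cross-encoder pair construction: a single text_1 is
# scored against every text_2, and matching-length lists are zipped.
from typing import List, Tuple, Union


def make_pairs_sketch(text_1: Union[str, List[str]],
                      text_2: Union[str, List[str]]) -> List[Tuple[str, str]]:
    left = [text_1] if isinstance(text_1, str) else list(text_1)
    right = [text_2] if isinstance(text_2, str) else list(text_2)
    if len(left) == 1:
        # A single query is paired with every candidate text.
        left = left * len(right)
    return list(zip(left, right))


pairs = make_pairs_sketch(
    "What is the capital of France?",
    ["The capital of Brazil is Brasilia.", "The capital of France is Paris."])
assert len(pairs) == 2  # each pair is tokenized and validated independently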