Commit 3da1426

K-Mistele authored and GWS0428 committed
[Frontend] Rerank API (Jina- and Cohere-compatible API) (vllm-project#12376)
Signed-off-by: Kyle Mistele <[email protected]>
1 parent 9172ffb commit 3da1426

File tree: 9 files changed (+552, −11 lines)


docs/source/serving/openai_compatible_server.md

Lines changed: 92 additions & 0 deletions
@@ -50,6 +50,11 @@ In addition, we have the following custom APIs:
   - Applicable to all [pooling models](../models/pooling_models.md).
 - [Score API](#score-api) (`/score`)
   - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
+- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
+  - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
+  - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
+  - Jina's and Cohere's APIs are very similar; Jina's includes extra information in the re-rank endpoint's response.
+  - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
 
 (chat-template)=
 
@@ -473,3 +478,90 @@ The following extra parameters are supported:
 :start-after: begin-score-extra-params
 :end-before: end-score-extra-params
 ```
+
+(rerank-api)=
+
+### Re-rank API
+
+Our Re-rank API applies a cross-encoder model to predict relevance scores between a single query
+and each of a list of documents. Usually, the score for a sentence pair refers to the similarity
+between the two sentences, on a scale of 0 to 1.
+
+You can find the documentation for these kinds of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
+
+The re-rank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other
+models supporting the `score` task. Additionally, the `/rerank`, `/v1/rerank`, and `/v2/rerank`
+endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and
+[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure
+compatibility with popular open-source tools.
+
+Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py>
+
+#### Example Request
+
+Note that the `top_n` request parameter is optional and defaults to the length of the `documents`
+field. Result documents will be sorted by relevance, and the `index` property can be used to
+determine the original order.
+
+Request:
+
+```bash
+curl -X 'POST' \
+  'http://127.0.0.1:8000/v1/rerank' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "model": "BAAI/bge-reranker-base",
+  "query": "What is the capital of France?",
+  "documents": [
+    "The capital of Brazil is Brasilia.",
+    "The capital of France is Paris.",
+    "Horses and cows are both animals"
+  ]
+}'
+```
+
+Response:
+
+```json
+{
+  "id": "rerank-fae51b2b664d4ed38f5969b612edff77",
+  "model": "BAAI/bge-reranker-base",
+  "usage": {
+    "total_tokens": 56
+  },
+  "results": [
+    {
+      "index": 1,
+      "document": {
+        "text": "The capital of France is Paris."
+      },
+      "relevance_score": 0.99853515625
+    },
+    {
+      "index": 0,
+      "document": {
+        "text": "The capital of Brazil is Brasilia."
+      },
+      "relevance_score": 0.0005860328674316406
+    }
+  ]
+}
+```
+
+#### Extra parameters
+
+The following [pooling parameters](#pooling-params) are supported.
+
+```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-rerank-pooling-params
+:end-before: end-rerank-pooling-params
+```
+
+The following extra parameters are supported:
+
+```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-rerank-extra-params
+:end-before: end-rerank-extra-params
+```
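The `index` field above is what lets a client recover the request order after results come back sorted by relevance. A minimal sketch of that, using the example response from the docs as a hard-coded dict (in a real client it would come from `response.json()`):

```python
# Rerank results arrive sorted by relevance_score; each entry's `index`
# points back into the original `documents` list from the request.
response = {
    "results": [
        {"index": 1,
         "document": {"text": "The capital of France is Paris."},
         "relevance_score": 0.99853515625},
        {"index": 0,
         "document": {"text": "The capital of Brazil is Brasilia."},
         "relevance_score": 0.0005860328674316406},
    ]
}

# Sorting on `index` restores the order the documents were sent in.
in_request_order = sorted(response["results"], key=lambda r: r["index"])
print([r["index"] for r in in_request_order])  # → [0, 1]
```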
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+"""
+Example of using the OpenAI entrypoint's rerank API which is compatible with
+the Cohere SDK: https://github.com/cohere-ai/cohere-python
+
+run: vllm serve BAAI/bge-reranker-base
+"""
+import cohere
+
+# cohere v1 client
+co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
+rerank_v1_result = co.rerank(
+    model="BAAI/bge-reranker-base",
+    query="What is the capital of France?",
+    documents=[
+        "The capital of France is Paris", "Reranking is fun!",
+        "vLLM is an open-source framework for fast AI serving"
+    ])
+
+print(rerank_v1_result)
+
+# or the v2 client
+co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
+
+v2_rerank_result = co2.rerank(
+    model="BAAI/bge-reranker-base",
+    query="What is the capital of France?",
+    documents=[
+        "The capital of France is Paris", "Reranking is fun!",
+        "vLLM is an open-source framework for fast AI serving"
+    ])
+
+print(v2_rerank_result)
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+"""
+Example of using the OpenAI entrypoint's rerank API which is compatible with
+Jina and Cohere: https://jina.ai/reranker
+
+run: vllm serve BAAI/bge-reranker-base
+"""
+import json
+
+import requests
+
+url = "http://127.0.0.1:8000/rerank"
+
+headers = {"accept": "application/json", "Content-Type": "application/json"}
+
+data = {
+    "model": "BAAI/bge-reranker-base",
+    "query": "What is the capital of France?",
+    "documents": [
+        "The capital of Brazil is Brasilia.",
+        "The capital of France is Paris.",
+        "Horses and cows are both animals"
+    ]
+}
+response = requests.post(url, headers=headers, json=data)
+
+# Check the response
+if response.status_code == 200:
+    print("Request successful!")
+    print(json.dumps(response.json(), indent=2))
+else:
+    print(f"Request failed with status code: {response.status_code}")
+    print(response.text)
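The raw-`requests` example above always ranks every document. Per the docs in this commit, `top_n` is optional and defaults to the number of documents; a small sketch of a payload-building helper that follows that convention (the helper name and structure are hypothetical, not part of this commit):

```python
# Hypothetical helper for building a /rerank request body.
# Mirrors the documented behavior: `top_n` is optional and, when
# omitted, the server defaults it to len(documents).
def build_rerank_payload(model, query, documents, top_n=None):
    payload = {
        "model": model,
        "query": query,
        "documents": list(documents),
    }
    # Only include top_n when the caller wants fewer results back.
    if top_n is not None:
        payload["top_n"] = top_n
    return payload


payload = build_rerank_payload(
    "BAAI/bge-reranker-base",
    "What is the capital of France?",
    ["The capital of Brazil is Brasilia.",
     "The capital of France is Paris."],
    top_n=1,
)
print(payload["top_n"])  # → 1
```

The payload can then be sent with `requests.post(url, headers=headers, json=payload)` exactly as in the example above.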
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
1+
import pytest
2+
import requests
3+
4+
from vllm.entrypoints.openai.protocol import RerankResponse
5+
6+
from ...utils import RemoteOpenAIServer
7+
8+
MODEL_NAME = "BAAI/bge-reranker-base"
9+
10+
11+
@pytest.fixture(scope="module")
12+
def server():
13+
args = ["--enforce-eager", "--max-model-len", "100"]
14+
15+
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
16+
yield remote_server
17+
18+
19+
@pytest.mark.asyncio
20+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
21+
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
22+
query = "What is the capital of France?"
23+
documents = [
24+
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
25+
]
26+
27+
rerank_response = requests.post(server.url_for("rerank"),
28+
json={
29+
"model": model_name,
30+
"query": query,
31+
"documents": documents,
32+
})
33+
rerank_response.raise_for_status()
34+
rerank = RerankResponse.model_validate(rerank_response.json())
35+
36+
assert rerank.id is not None
37+
assert rerank.results is not None
38+
assert len(rerank.results) == 2
39+
assert rerank.results[0].relevance_score >= 0.9
40+
assert rerank.results[1].relevance_score <= 0.01
41+
42+
43+
@pytest.mark.asyncio
44+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
45+
def test_top_n(server: RemoteOpenAIServer, model_name: str):
46+
query = "What is the capital of France?"
47+
documents = [
48+
"The capital of Brazil is Brasilia.",
49+
"The capital of France is Paris.", "Cross-encoder models are neat"
50+
]
51+
52+
rerank_response = requests.post(server.url_for("rerank"),
53+
json={
54+
"model": model_name,
55+
"query": query,
56+
"documents": documents,
57+
"top_n": 2
58+
})
59+
rerank_response.raise_for_status()
60+
rerank = RerankResponse.model_validate(rerank_response.json())
61+
62+
assert rerank.id is not None
63+
assert rerank.results is not None
64+
assert len(rerank.results) == 2
65+
assert rerank.results[0].relevance_score >= 0.9
66+
assert rerank.results[1].relevance_score <= 0.01
67+
68+
69+
@pytest.mark.asyncio
70+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
71+
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
72+
73+
query = "What is the capital of France?" * 100
74+
documents = [
75+
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
76+
]
77+
78+
rerank_response = requests.post(server.url_for("rerank"),
79+
json={
80+
"model": model_name,
81+
"query": query,
82+
"documents": documents
83+
})
84+
assert rerank_response.status_code == 400
85+
# Assert just a small fragments of the response
86+
assert "Please reduce the length of the input." in \
87+
rerank_response.text

tests/entrypoints/openai/test_score.py

Lines changed: 1 addition & 6 deletions
@@ -10,12 +10,7 @@
 
 
 @pytest.fixture(scope="module")
 def server():
-    args = [
-        "--enforce-eager",
-        # Will be used on tests to compare prompt input length
-        "--max-model-len",
-        "100"
-    ]
+    args = ["--enforce-eager", "--max-model-len", "100"]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server

vllm/entrypoints/openai/api_server.py

Lines changed: 50 additions & 1 deletion
@@ -56,6 +56,7 @@
     PoolingChatRequest,
     PoolingCompletionRequest,
     PoolingRequest, PoolingResponse,
+    RerankRequest, RerankResponse,
     ScoreRequest, ScoreResponse,
     TokenizeRequest,
     TokenizeResponse,
@@ -68,6 +69,7 @@
 from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                     OpenAIServingModels)
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
+from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
 from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
     return request.app.state.openai_serving_scores
 
 
+def rerank(request: Request) -> Optional[JinaAIServingRerank]:
+    return request.app.state.jinaai_serving_reranking
+
+
 def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization
 
@@ -502,6 +508,40 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)
 
 
+@router.post("/rerank")
+@with_cancellation
+async def do_rerank(request: RerankRequest, raw_request: Request):
+    handler = rerank(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Rerank (Score) API")
+    generator = await handler.do_rerank(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, RerankResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
+@router.post("/v1/rerank")
+@with_cancellation
+async def do_rerank_v1(request: RerankRequest, raw_request: Request):
+    logger.warning(
+        "To indicate that the rerank API is not part of the standard OpenAI"
+        " API, we have located it at `/rerank`. Please update your client"
+        " accordingly. (Note: Conforms to JinaAI rerank API)")
+
+    return await do_rerank(request, raw_request)
+
+
+@router.post("/v2/rerank")
+@with_cancellation
+async def do_rerank_v2(request: RerankRequest, raw_request: Request):
+    return await do_rerank(request, raw_request)
+
+
 TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
     "generate": {
         "messages": (ChatCompletionRequest, create_chat_completion),
@@ -512,7 +552,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
         "default": (EmbeddingCompletionRequest, create_embedding),
     },
     "score": {
-        "default": (ScoreRequest, create_score),
+        "default": (RerankRequest, do_rerank)
+    },
+    "rerank": {
+        "default": (RerankRequest, do_rerank)
     },
     "reward": {
         "messages": (PoolingChatRequest, create_pooling),
@@ -759,6 +802,12 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger
     ) if model_config.task == "score" else None
+    state.jinaai_serving_reranking = JinaAIServingRerank(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger
+    ) if model_config.task == "score" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
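The `TASK_HANDLERS` change above routes both the `score` and `rerank` tasks to the same `do_rerank` handler, with a per-task `"default"` entry used when no more specific key (such as `"messages"`) matches. A toy sketch of that dict-based dispatch pattern, using stand-in handler functions rather than vLLM's real ones:

```python
# Toy model of TASK_HANDLERS-style dispatch. The handlers here are
# stand-ins; vLLM's real entries pair a request class with a handler.
def do_rerank(request):
    return f"rerank:{request}"


def create_embedding(request):
    return f"embed:{request}"


TASK_HANDLERS = {
    "score": {"default": do_rerank},    # score task now reuses rerank
    "rerank": {"default": do_rerank},
    "embed": {"default": create_embedding},
}


def dispatch(task, request, kind="default"):
    # Fall back to the task's "default" handler when no kind-specific
    # handler is registered for this task.
    handlers = TASK_HANDLERS[task]
    handler = handlers.get(kind, handlers["default"])
    return handler(request)


print(dispatch("rerank", "q1"))  # → rerank:q1
```

This mirrors why the commit can map two task names to one endpoint: the rerank handler already covers the cross-encoder scoring path.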
