Commit eeec9e3

[Frontend] Separate pooling APIs in offline inference (#11129)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent f93bf2b commit eeec9e3

21 files changed: 659 additions, 294 deletions
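
At a glance, this commit splits the catch-all `LLM.encode` entry point into task-specific pooling methods. A minimal sketch of the new offline API, assuming a vLLM build that includes this commit (the model name is taken from the embedding example below):

    from vllm import LLM

    # Previously, every pooling task went through llm.encode(...).
    # After this commit, each task has its own method and output type:
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")

    (output,) = llm.embed(["Hello, my name is"])  # one EmbeddingRequestOutput per prompt
    print(len(output.outputs.embedding))          # the embedding vector for that prompt

    # llm.classify(...) and llm.score(...) work analogously for
    # classification models and cross-encoders, per the docs diff below.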

.buildkite/test-pipeline.yaml

Lines changed: 5 additions & 2 deletions
@@ -181,14 +181,14 @@ steps:
   commands:
   - VLLM_USE_V1=1 pytest -v -s v1
 
-- label: Examples Test # 15min
+- label: Examples Test # 25min
   working_dir: "/vllm-workspace/examples"
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/entrypoints
   - examples/
   commands:
-  - pip install awscli tensorizer # for llava example and tensorizer test
+  - pip install tensorizer # for tensorizer test
   - python3 offline_inference.py
   - python3 cpu_offload.py
   - python3 offline_inference_chat.py
@@ -198,6 +198,9 @@ steps:
   - python3 offline_inference_vision_language_multi_image.py
   - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
   - python3 offline_inference_encoder_decoder.py
+  - python3 offline_inference_classification.py
+  - python3 offline_inference_embedding.py
+  - python3 offline_inference_scoring.py
   - python3 offline_profile.py --model facebook/opt-125m
 
 - label: Prefix Caching Test # 9min

docs/source/models/pooling_models.rst

Lines changed: 45 additions & 8 deletions
@@ -6,7 +6,7 @@ Pooling Models
 vLLM also supports pooling models, including embedding, reranking and reward models.
 
 In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
-These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
+These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input
 before returning them.
 
 .. note::
@@ -45,20 +45,48 @@ which takes priority over both the model's and Sentence Transformers's defaults.
 ^^^^^^^^^^^^^^
 
 The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
-It returns the aggregated hidden states directly.
+It returns the extracted hidden states directly, which is useful for reward models.
+
+.. code-block:: python
+
+    llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
+    output, = llm.encode("Hello, my name is")
+
+    data = output.outputs.data
+    print(f"Prompt: {prompt!r} | Data: {data!r}")
+
+``LLM.embed``
+^^^^^^^^^^^^^
+
+The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt.
+It is primarily designed for embedding models.
 
 .. code-block:: python
 
     llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
-    outputs = llm.encode("Hello, my name is")
+    output, = llm.embed("Hello, my name is")
 
-    outputs = model.encode(prompts)
-    for output in outputs:
-        embeddings = output.outputs.embedding
-        print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}")
+    embeds = output.outputs.embedding
+    print(f"Embeddings: {embeds!r} (size={len(embeds)})")
 
 A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.
 
+``LLM.classify``
+^^^^^^^^^^^^^^^^
+
+The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt.
+It is primarily designed for classification models.
+
+.. code-block:: python
+
+    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
+    output, = llm.classify("Hello, my name is")
+
+    probs = output.outputs.probs
+    print(f"Class Probabilities: {probs!r} (size={len(probs)})")
+
+A code example can be found in `examples/offline_inference_classification.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_classification.py>`_.
+
 ``LLM.score``
 ^^^^^^^^^^^^^
 
@@ -71,7 +99,16 @@ These types of models serve as rerankers between candidate query-document pairs
 vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
 To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.
 
-You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference.
+.. code-block:: python
+
+    llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
+    output, = llm.score("What is the capital of France?",
+                        "The capital of Brazil is Brasilia.")
+
+    score = output.outputs.score
+    print(f"Score: {score}")
+
+A code example can be found in `examples/offline_inference_scoring.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_scoring.py>`_.
 
 Online Inference
 ----------------
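
Note that the ``LLM.encode`` snippet added above prints `prompt` without defining it. A self-contained version of that reward-model example might look like the following; only the model name, task, and output fields come from the diff, the rest is an assumption:

    from vllm import LLM

    prompt = "Hello, my name is"

    llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
    (output,) = llm.encode(prompt)

    # For reward models, the pooled hidden states are exposed as `data`.
    data = output.outputs.data
    print(f"Prompt: {prompt!r} | Data: {data!r}")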

examples/offline_inference_classification.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+from vllm import LLM
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create an LLM.
+# You should pass task="classify" for classification models
+model = LLM(
+    model="jason9693/Qwen2.5-1.5B-apeach",
+    task="classify",
+    enforce_eager=True,
+)
+
+# Generate logits. The output is a list of ClassificationRequestOutputs.
+outputs = model.classify(prompts)
+
+# Print the outputs.
+for prompt, output in zip(prompts, outputs):
+    probs = output.outputs.probs
+    probs_trimmed = ((str(probs[:16])[:-1] +
+                      ", ...]") if len(probs) > 16 else probs)
+    print(f"Prompt: {prompt!r} | "
+          f"Class Probabilities: {probs_trimmed} (size={len(probs)})")

examples/offline_inference_embedding.py

Lines changed: 11 additions & 5 deletions
@@ -9,14 +9,20 @@
 ]
 
 # Create an LLM.
+# You should pass task="embed" for embedding models
 model = LLM(
     model="intfloat/e5-mistral-7b-instruct",
-    task="embed", # You should pass task="embed" for embedding models
+    task="embed",
     enforce_eager=True,
 )
 
-# Generate embedding. The output is a list of PoolingRequestOutputs.
-outputs = model.encode(prompts)
+# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+outputs = model.embed(prompts)
+
 # Print the outputs.
-for output in outputs:
-    print(output.outputs.embedding)  # list of 4096 floats
+for prompt, output in zip(prompts, outputs):
+    embeds = output.outputs.embedding
+    embeds_trimmed = ((str(embeds[:16])[:-1] +
+                       ", ...]") if len(embeds) > 16 else embeds)
+    print(f"Prompt: {prompt!r} | "
+          f"Embeddings: {embeds_trimmed} (size={len(embeds)})")

examples/offline_inference_scoring.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from vllm import LLM
+
+# Sample prompts.
+text_1 = "What is the capital of France?"
+texts_2 = [
+    "The capital of Brazil is Brasilia.", "The capital of France is Paris."
+]
+
+# Create an LLM.
+# You should pass task="score" for cross-encoder models
+model = LLM(
+    model="BAAI/bge-reranker-v2-m3",
+    task="score",
+    enforce_eager=True,
+)
+
+# Generate scores. The output is a list of ScoringRequestOutputs.
+outputs = model.score(text_1, texts_2)
+
+# Print the outputs.
+for text_2, output in zip(texts_2, outputs):
+    score = output.outputs.score
+    print(f"Pair: {[text_1, text_2]!r} | Score: {score}")

examples/offline_inference_vision_language_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def run_encode(model: str, modality: QueryModality):
     if req_data.image is not None:
         mm_data["image"] = req_data.image
 
-    outputs = req_data.llm.encode({
+    outputs = req_data.llm.embed({
         "prompt": req_data.prompt,
         "multi_modal_data": mm_data,
     })

tests/conftest.py

Lines changed: 7 additions & 11 deletions
@@ -719,14 +719,6 @@ def get_inputs(
 
         return inputs
 
-    def classify(self, prompts: List[str]) -> List[str]:
-        req_outputs = self.model.encode(prompts)
-        outputs = []
-        for req_output in req_outputs:
-            embedding = req_output.outputs.embedding
-            outputs.append(embedding)
-        return outputs
-
     def generate(
         self,
         prompts: List[str],
@@ -897,6 +889,10 @@ def generate_beam_search(
             returned_outputs.append((token_ids, texts))
         return returned_outputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.classify(prompts)
+        return [req_output.outputs.probs for req_output in req_outputs]
+
     def encode(
         self,
         prompts: List[str],
@@ -909,16 +905,16 @@ def encode(
             videos=videos,
             audios=audios)
 
-        req_outputs = self.model.encode(inputs)
+        req_outputs = self.model.embed(inputs)
         return [req_output.outputs.embedding for req_output in req_outputs]
 
     def score(
         self,
         text_1: Union[str, List[str]],
         text_2: Union[str, List[str]],
-    ) -> List[List[float]]:
+    ) -> List[float]:
         req_outputs = self.model.score(text_1, text_2)
-        return [req_output.outputs.embedding for req_output in req_outputs]
+        return [req_output.outputs.score for req_output in req_outputs]
 
     def __enter__(self):
         return self
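
With these helpers, `VllmRunner.classify` returns one probability vector per prompt and `VllmRunner.score` returns plain floats, which is what the updated tests below rely on. A hedged sketch of how a test might call them, assuming the existing `vllm_runner` fixture accepts a `task` argument (the model name is taken from the classification example above):

    def test_classify_probs_shape(vllm_runner):
        prompts = ["Hello, my name is"]
        with vllm_runner("jason9693/Qwen2.5-1.5B-apeach",
                         task="classify") as vllm_model:
            probs = vllm_model.classify(prompts)
        # One probability vector per prompt, no extra nesting.
        assert len(probs) == len(prompts)
        assert isinstance(probs[0], list)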

tests/entrypoints/openai/test_score.py

Lines changed: 5 additions & 5 deletions
@@ -39,8 +39,8 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
     assert score.id is not None
     assert score.data is not None
     assert len(score.data) == 2
-    assert score.data[0].score[0] <= 0.01
-    assert score.data[1].score[0] >= 0.9
+    assert score.data[0].score <= 0.01
+    assert score.data[1].score >= 0.9
 
 
 @pytest.mark.asyncio
@@ -67,8 +67,8 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
     assert score.id is not None
     assert score.data is not None
     assert len(score.data) == 2
-    assert score.data[0].score[0] <= 0.01
-    assert score.data[1].score[0] >= 0.9
+    assert score.data[0].score <= 0.01
+    assert score.data[1].score >= 0.9
 
 
 @pytest.mark.asyncio
@@ -90,4 +90,4 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
     assert score.id is not None
     assert score.data is not None
     assert len(score.data) == 1
-    assert score.data[0].score[0] >= 0.9
+    assert score.data[0].score >= 0.9

tests/models/embedding/language/test_scoring.py

Lines changed: 5 additions & 5 deletions
@@ -42,7 +42,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
     assert len(vllm_outputs) == 1
     assert len(hf_outputs) == 1
 
-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
 
 
 @pytest.mark.parametrize("dtype", ["half"])
@@ -63,8 +63,8 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
     assert len(vllm_outputs) == 2
     assert len(hf_outputs) == 2
 
-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
-    assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
 
 
 @pytest.mark.parametrize("dtype", ["half"])
@@ -85,5 +85,5 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
     assert len(vllm_outputs) == 2
     assert len(hf_outputs) == 2
 
-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
-    assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)

tests/models/test_oot_registration.py

Lines changed: 2 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import pytest
 
-from vllm import LLM, PoolingParams, SamplingParams
+from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 
 from ..utils import fork_new_process_for_each_test
@@ -36,9 +36,8 @@ def test_oot_registration_text_generation(dummy_opt_path):
 def test_oot_registration_embedding(dummy_gemma2_embedding_path):
     os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = ["Hello, my name is", "The text does not matter"]
-    sampling_params = PoolingParams()
     llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
-    outputs = llm.encode(prompts, sampling_params)
+    outputs = llm.embed(prompts)
 
     for output in outputs:
         assert all(v == 0 for v in output.outputs.embedding)

vllm/__init__.py

Lines changed: 11 additions & 25 deletions
@@ -7,8 +7,11 @@
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, PoolingOutput,
-                          PoolingRequestOutput, RequestOutput)
+from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
+                          CompletionOutput, EmbeddingOutput,
+                          EmbeddingRequestOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput, ScoringOutput,
+                          ScoringRequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 
@@ -27,33 +30,16 @@
     "CompletionOutput",
     "PoolingOutput",
     "PoolingRequestOutput",
+    "EmbeddingOutput",
+    "EmbeddingRequestOutput",
+    "ClassificationOutput",
+    "ClassificationRequestOutput",
+    "ScoringOutput",
+    "ScoringRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
     "initialize_ray_cluster",
     "PoolingParams",
 ]
-
-
-def __getattr__(name: str):
-    import warnings
-
-    if name == "EmbeddingOutput":
-        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingOutput
-
-    if name == "EmbeddingRequestOutput":
-        msg = ("EmbeddingRequestOutput has been renamed to "
-               "PoolingRequestOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingRequestOutput
-
-raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
