
Commit d2f058e

[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent: f877a7d

25 files changed (+166 −123 lines)

examples/offline_inference_embedding.py (1 addition, 1 deletion)

@@ -10,7 +10,7 @@
 
 # Create an LLM.
 model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
-# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+# Generate embedding. The output is a list of PoolingRequestOutputs.
 outputs = model.encode(prompts)
 # Print the outputs.
 for output in outputs:
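For context, a minimal end-to-end sketch of the updated example follows. The model name and the encode() call come from the diff above; the prompts list and the print statement are illustrative additions, not part of the original file.

from vllm import LLM

# Illustrative prompts; the real example defines its own list earlier in the file.
prompts = [
    "Hello, my name is",
    "The capital of France is",
]

# Create an LLM (as in the example above).
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)

# encode() now returns a list of PoolingRequestOutput objects.
outputs = model.encode(prompts)

for output in outputs:
    # Each PoolingRequestOutput carries the pooled result in output.outputs.
    print(output.request_id, output.outputs)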

tests/entrypoints/llm/test_encode.py (3 additions, 3 deletions)

@@ -3,7 +3,7 @@
 
 import pytest
 
-from vllm import LLM, EmbeddingRequestOutput, PoolingParams
+from vllm import LLM, PoolingParams, PoolingRequestOutput
 from vllm.distributed import cleanup_dist_env_and_memory
 
 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -43,8 +43,8 @@ def llm():
     cleanup_dist_env_and_memory()
 
 
-def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
-                         o2: List[EmbeddingRequestOutput]):
+def assert_outputs_equal(o1: List[PoolingRequestOutput],
+                         o2: List[PoolingRequestOutput]):
     assert [o.outputs for o in o1] == [o.outputs for o in o2]

tests/models/test_registry.py (2 additions, 2 deletions)

@@ -3,7 +3,7 @@
 import pytest
 import torch.cuda
 
-from vllm.model_executor.models import (is_embedding_model,
+from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
 from vllm.model_executor.models.adapters import as_embedding_model
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):
 
     # All vLLM models should be convertible to an embedding model
     embed_model = as_embedding_model(model_cls)
-    assert is_embedding_model(embed_model)
+    assert is_pooling_model(embed_model)
 
     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)
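Below is a standalone sketch of the check this test now performs. is_pooling_model and as_embedding_model come from the imports in the diff; ModelRegistry.resolve_model_cls and the "LlamaForCausalLM" architecture name are assumptions made here for illustration.

from vllm.model_executor.models import ModelRegistry, is_pooling_model
from vllm.model_executor.models.adapters import as_embedding_model

# Resolve a registered architecture to its model class
# (the helper name and architecture are assumptions for this sketch).
model_cls, _ = ModelRegistry.resolve_model_cls(["LlamaForCausalLM"])

# Wrap the generative model class so it exposes a pooling interface ...
embed_model = as_embedding_model(model_cls)

# ... and verify it via the renamed check.
assert is_pooling_model(embed_model)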

tests/worker/test_model_input.py (2 additions, 2 deletions)

@@ -8,10 +8,10 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.worker.embedding_model_runner import (
-    ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 from vllm.worker.multi_step_model_runner import StatefulModelInput
+from vllm.worker.pooling_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)
 
 
 class MockAttentionBackend(AttentionBackend):

vllm/__init__.py (27 additions, 4 deletions)

@@ -7,8 +7,8 @@
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, EmbeddingOutput,
-                          EmbeddingRequestOutput, RequestOutput)
+from vllm.outputs import (CompletionOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 
@@ -25,12 +25,35 @@
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "EmbeddingOutput",
-    "EmbeddingRequestOutput",
+    "PoolingOutput",
+    "PoolingRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
     "initialize_ray_cluster",
     "PoolingParams",
 ]
+
+
+def __getattr__(name: str):
+    import warnings
+
+    if name == "EmbeddingOutput":
+        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingOutput
+
+    if name == "EmbeddingRequestOutput":
+        msg = ("EmbeddingRequestOutput has been renamed to "
+               "PoolingRequestOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingRequestOutput
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
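As a usage note, the module-level __getattr__ above keeps the old names importable while flagging the rename. A minimal sketch of the expected behaviour (this relies on PEP 562 module __getattr__, available since Python 3.7; the snippet itself is not part of the commit):

import warnings

import vllm

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    # The old attribute is no longer defined directly, so it resolves
    # through vllm.__getattr__ and takes the deprecation path above.
    old_cls = vllm.EmbeddingRequestOutput

# The old name is an alias for the new class, and a DeprecationWarning was emitted.
assert old_cls is vllm.PoolingRequestOutput
assert any(issubclass(w.category, DeprecationWarning) for w in caught)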

vllm/config.py (1 addition, 1 deletion)

@@ -359,7 +359,7 @@ def _resolve_task(
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
-            "embedding": ModelRegistry.is_embedding_model(architectures),
+            "embedding": ModelRegistry.is_pooling_model(architectures),
         }
         supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported

vllm/engine/async_llm_engine.py (12 additions, 12 deletions)

@@ -25,7 +25,7 @@
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -74,7 +74,7 @@ def _log_task_completion(task: asyncio.Task,
 
 
 class AsyncStream:
-    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    """A stream of RequestOutputs or PoolingRequestOutputs for a request
     that can be iterated over asynchronously via an async generator."""
 
     def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
@@ -83,7 +83,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
 
-    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -103,7 +103,7 @@ def finished(self) -> bool:
 
     async def generator(
         self
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         try:
             while True:
                 result = await self._queue.get()
@@ -154,7 +154,7 @@ def propagate_exception(self,
 
     def process_request_output(self,
                                request_output: Union[RequestOutput,
-                                                     EmbeddingRequestOutput],
+                                                     PoolingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -265,7 +265,7 @@ def __init__(self, *args, **kwargs):
 
     async def step_async(
         self, virtual_engine: int
-    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.
 
@@ -907,7 +907,7 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...
 
     @overload
@@ -922,7 +922,7 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...
 
     @deprecate_kwargs(
@@ -941,7 +941,7 @@ async def add_request(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -1070,7 +1070,7 @@ async def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -1088,7 +1088,7 @@ async def encode(
                 Only applicable with priority scheduling.
 
         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.
 
         Details:
@@ -1141,7 +1141,7 @@ async def encode(
             trace_headers=trace_headers,
             priority=priority,
         ):
-            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
+            yield LLMEngine.validate_output(output, PoolingRequestOutput)
 
     async def abort(self, request_id: str) -> None:
         """Abort a request.

vllm/engine/llm_engine.py (4 additions, 4 deletions)

@@ -40,7 +40,7 @@
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
+from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
 
 
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
-_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
 
 
 @dataclass
@@ -112,7 +112,7 @@ class SchedulerContext:
     def __init__(self, multi_step_stream_outputs: bool = False):
         self.output_queue: Deque[OutputData] = deque()
         self.request_outputs: List[Union[RequestOutput,
                                          EmbeddingRequestOutput]] = []
+                                         PoolingRequestOutput]] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
@@ -1314,7 +1314,7 @@ def _advance_to_next_step(
             else:
                 seq.append_token_id(sample.output_token, sample.logprobs)
 
-    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
 
         .. figure:: https://i.imgur.com/sv2HssD.png

vllm/engine/multiprocessing/client.py (7 additions, 7 deletions)

@@ -35,7 +35,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -495,7 +495,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...
 
     @overload
@@ -507,7 +507,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...
 
     @deprecate_kwargs(
@@ -524,7 +524,7 @@ def encode(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None  # DEPRECATED
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -540,7 +540,7 @@ def encode(
             trace_headers: OpenTelemetry trace headers.
 
         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.
         """
         if inputs is not None:
@@ -549,7 +549,7 @@ def encode(
                           and request_id is not None)
 
         return cast(
-            AsyncGenerator[EmbeddingRequestOutput, None],
+            AsyncGenerator[PoolingRequestOutput, None],
             self._process_request(prompt,
                                   pooling_params,
                                   request_id,
@@ -567,7 +567,7 @@ async def _process_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
-            EmbeddingRequestOutput, None]]:
+            PoolingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
 
         # If already dead, error out.

vllm/engine/protocol.py (2 additions, 3 deletions)

@@ -11,8 +11,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -209,7 +208,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...
