Commit 60cf35b

aandyw authored and committed, with DarkLight1337 and ywang96

[Frontend] Batch inference for llm.chat() API (vllm-project#8648)

Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Alvant <[email protected]>
1 parent 2d986fc commit 60cf35b
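
Summary: with this change, LLM.chat() accepts either a single conversation or a list of conversations. A list is converted into one prompt per conversation and dispatched through a single generate() call, so all conversations are processed as one batch. A minimal usage sketch, assuming the Meta-Llama-3-8B-Instruct model used in the example and test below; the sampling settings here are illustrative, not part of the commit:

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
sampling_params = SamplingParams(temperature=0.5)

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Write an essay about the importance of higher education."},
]

# A list of conversations is treated as a batch; passing a single
# conversation (the pre-existing call style) still works unchanged.
outputs = llm.chat(messages=[conversation] * 10,
                   sampling_params=sampling_params,
                   use_tqdm=True)
for output in outputs:
    print(output.outputs[0].text)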

File tree

3 files changed: +111 -33 lines changed

examples/offline_inference_chat.py
tests/entrypoints/llm/test_generate.py
vllm/entrypoints/llm.py

examples/offline_inference_chat.py

Lines changed: 27 additions & 0 deletions

@@ -39,6 +39,33 @@ def print_outputs(outputs):
                    use_tqdm=False)
 print_outputs(outputs)
 
+# You can run batch inference with llm.chat API
+conversation = [
+    {
+        "role": "system",
+        "content": "You are a helpful assistant"
+    },
+    {
+        "role": "user",
+        "content": "Hello"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello! How can I assist you today?"
+    },
+    {
+        "role": "user",
+        "content": "Write an essay about the importance of higher education.",
+    },
+]
+conversations = [conversation for _ in range(10)]
+
+# We turn on tqdm progress bar to verify it's indeed running batch inference
+outputs = llm.chat(messages=conversations,
+                   sampling_params=sampling_params,
+                   use_tqdm=True)
+print_outputs(outputs)
+
 # A chat template can be optionally supplied.
 # If not, the model will use its default chat template.
 

tests/entrypoints/llm/test_generate.py

Lines changed: 35 additions & 0 deletions

@@ -162,6 +162,41 @@ def test_chat():
     assert len(outputs) == 1
 
 
+def test_multi_chat():
+
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    prompt2 = "Explain what among us is."
+
+    conversation1 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+
+    conversation2 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt2
+        },
+    ]
+
+    messages = [conversation1, conversation2]
+
+    outputs = llm.chat(messages)
+    assert len(outputs) == 2
+
+
 @pytest.mark.parametrize("image_urls",
                          [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
 def test_chat_multi_image(image_urls: List[str]):
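The new test only exercises the batched path; the existing test_chat above it continues to cover the single-conversation path. Per the docstring change below, outputs are returned in the same order as the input conversations, so callers can pair inputs with replies positionally. A small sketch of that pairing, reusing the names defined in test_multi_chat; the pairing loop itself is illustrative and not part of the commit:

# Sketch: pair each conversation with its generated reply by position.
# Assumes `llm`, `conversation1`, and `conversation2` from test_multi_chat.
messages = [conversation1, conversation2]
outputs = llm.chat(messages)
for conversation, output in zip(messages, outputs):
    last_user_turn = conversation[-1]["content"]
    reply = output.outputs[0].text
    print(f"Q: {last_user_turn!r}\nA: {reply!r}")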

vllm/entrypoints/llm.py

Lines changed: 49 additions & 33 deletions

@@ -485,7 +485,8 @@ def beam_search(
 
     def chat(
         self,
-        messages: List[ChatCompletionMessageParam],
+        messages: Union[List[ChatCompletionMessageParam],
+                        List[List[ChatCompletionMessageParam]]],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         use_tqdm: bool = True,
@@ -505,8 +506,9 @@ def chat(
         to the OpenAI API.
 
         Args:
-            messages: A single conversation represented as a list of messages.
-                Each message is a dictionary with 'role' and 'content' keys.
+            messages: A list of conversations or a single conversation.
+                - Each conversation is represented as a list of messages.
+                - Each message is a dictionary with 'role' and 'content' keys.
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -523,42 +525,56 @@ def chat(
             A list of ``RequestOutput`` objects containing the generated
             responses in the same order as the input messages.
         """
+        list_of_messages: List[List[ChatCompletionMessageParam]]
 
-        tokenizer = self.get_tokenizer()
-        model_config = self.llm_engine.get_model_config()
-
-        conversation, mm_data = parse_chat_messages(messages, model_config,
-                                                    tokenizer)
-
-        prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
-                tokenizer,
-                messages=messages,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+        # Handle multi and single conversations
+        if is_list_of(messages, list):
+            # messages is List[List[...]]
+            list_of_messages = messages
         else:
-            prompt = apply_hf_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+            # messages is List[...]
+            list_of_messages = [messages]
+
+        prompts: List[Union[TokensPrompt, TextPrompt]] = []
+
+        for msgs in list_of_messages:
+            tokenizer = self.get_tokenizer()
+            model_config = self.llm_engine.get_model_config()
+
+            conversation, mm_data = parse_chat_messages(
+                msgs, model_config, tokenizer)
+
+            prompt_data: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt_data = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=msgs,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+            else:
+                prompt_data = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+
+            prompt: Union[TokensPrompt, TextPrompt]
+            if is_list_of(prompt_data, int):
+                prompt = TokensPrompt(prompt_token_ids=prompt_data)
+            else:
+                prompt = TextPrompt(prompt=prompt_data)
 
-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
-        else:
-            inputs = TextPrompt(prompt=prompt)
+            if mm_data is not None:
+                prompt["multi_modal_data"] = mm_data
 
-        if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            prompts.append(prompt)
 
         return self.generate(
-            inputs,
+            prompts,
             sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
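Because chat() now builds one prompt per conversation and forwards sampling_params unchanged to generate(), passing a list of SamplingParams with one entry per conversation should apply them element-wise, mirroring the documented per-prompt behavior of generate(). A hedged sketch of that usage; the conversations, parameter values, and element-wise assumption are illustrative and not exercised by this commit's tests:

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

conversations = [
    [{"role": "user", "content": "Summarize the theory of relativity."}],
    [{"role": "user", "content": "Write a haiku about batching."}],
]

# One SamplingParams per conversation, assumed to be matched by position.
per_prompt_params = [
    SamplingParams(temperature=0.0, max_tokens=256),  # deterministic summary
    SamplingParams(temperature=0.8, max_tokens=64),   # more creative haiku
]

outputs = llm.chat(messages=conversations, sampling_params=per_prompt_params)
for output in outputs:
    print(output.outputs[0].text)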
