@@ -485,7 +485,8 @@ def beam_search(
 
     def chat(
         self,
-        messages: List[ChatCompletionMessageParam],
+        messages: Union[List[ChatCompletionMessageParam],
+                        List[List[ChatCompletionMessageParam]]],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         use_tqdm: bool = True,
@@ -505,8 +506,9 @@ def chat(
             to the OpenAI API.
 
         Args:
-            messages: A single conversation represented as a list of messages.
-                Each message is a dictionary with 'role' and 'content' keys.
+            messages: A list of conversations or a single conversation.
+                - Each conversation is represented as a list of messages.
+                - Each message is a dictionary with 'role' and 'content' keys.
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -523,42 +525,56 @@ def chat(
             A list of ``RequestOutput`` objects containing the generated
             responses in the same order as the input messages.
         """
+        list_of_messages: List[List[ChatCompletionMessageParam]]
 
-        tokenizer = self.get_tokenizer()
-        model_config = self.llm_engine.get_model_config()
-
-        conversation, mm_data = parse_chat_messages(messages, model_config,
-                                                    tokenizer)
-
-        prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
-                tokenizer,
-                messages=messages,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+        # Handle multi and single conversations
+        if is_list_of(messages, list):
+            # messages is List[List[...]]
+            list_of_messages = messages
         else:
-            prompt = apply_hf_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+            # messages is List[...]
+            list_of_messages = [messages]
+
+        prompts: List[Union[TokensPrompt, TextPrompt]] = []
+
+        for msgs in list_of_messages:
+            tokenizer = self.get_tokenizer()
+            model_config = self.llm_engine.get_model_config()
+
+            conversation, mm_data = parse_chat_messages(
+                msgs, model_config, tokenizer)
+
+            prompt_data: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt_data = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=msgs,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+            else:
+                prompt_data = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+
+            prompt: Union[TokensPrompt, TextPrompt]
+            if is_list_of(prompt_data, int):
+                prompt = TokensPrompt(prompt_token_ids=prompt_data)
+            else:
+                prompt = TextPrompt(prompt=prompt_data)
 
-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
-        else:
-            inputs = TextPrompt(prompt=prompt)
+            if mm_data is not None:
+                prompt["multi_modal_data"] = mm_data
 
-        if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            prompts.append(prompt)
 
         return self.generate(
-            inputs,
+            prompts,
             sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
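For reference, a minimal usage sketch of the API after this change: chat() now accepts either a single conversation or a list of conversations and returns one RequestOutput per conversation. This assumes a vLLM build that includes the change; the model name and prompts are arbitrary examples, not part of the commit.

    from vllm import LLM, SamplingParams

    # Arbitrary example model; any chat-tuned model supported by vLLM works.
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    sampling_params = SamplingParams(temperature=0.7, max_tokens=64)

    # Single conversation: a list of OpenAI-style message dicts.
    single = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the capital of France."},
    ]

    # Batch of conversations: a list of such lists, handled in one call
    # by the new List[List[...]] branch shown in the diff above.
    batch = [
        single,
        [{"role": "user", "content": "Write a haiku about the sea."}],
    ]

    outputs = llm.chat(batch, sampling_params=sampling_params)
    for out in outputs:
        print(out.outputs[0].text)

Note that the diff builds one TextPrompt or TokensPrompt per conversation and forwards the whole list to generate(), so batching happens in a single engine call rather than one call per conversation.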