Commit a26ad39

yeqcharlotte and yeq authored and committed
[Bugfix] fix beam search input errors and latency benchmark script (vllm-project#11875)
Signed-off-by: Ye Qi <[email protected]>
Co-authored-by: yeq <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent cac0bb5 commit a26ad39

File tree

2 files changed: +23 −10 lines changed

benchmarks/benchmark_latency.py

Lines changed: 17 additions & 6 deletions

```diff
@@ -13,6 +13,7 @@
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.inputs import PromptType
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -40,6 +41,20 @@ def main(args: argparse.Namespace):
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
 
+    def llm_generate():
+        if not args.use_beam_search:
+            llm.generate(dummy_prompts,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+        else:
+            llm.beam_search(
+                dummy_prompts,
+                BeamSearchParams(
+                    beam_width=args.n,
+                    max_tokens=args.output_len,
+                    ignore_eos=True,
+                ))
+
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
             with torch.profiler.profile(
@@ -49,15 +64,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
                     ],
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
                         str(profile_dir))) as p:
-                llm.generate(dummy_prompts,
-                             sampling_params=sampling_params,
-                             use_tqdm=False)
+                llm_generate()
             print(p.key_averages().table(sort_by="self_cuda_time_total"))
         else:
             start_time = time.perf_counter()
-            llm.generate(dummy_prompts,
-                         sampling_params=sampling_params,
-                         use_tqdm=False)
+            llm_generate()
             end_time = time.perf_counter()
             latency = end_time - start_time
             return latency
```
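For context, here is a minimal sketch (not part of the commit) of the call path the fixed benchmark now exercises: token prompts are passed as `prompt_token_ids` dicts and routed to `llm.beam_search` when beam search is enabled. The model name and token ids below are placeholder assumptions.

```python
# Sketch only: mirrors the beam-search path added to benchmark_latency.py.
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # placeholder model choice

# The benchmark builds token prompts as dicts; this matches the fixed
# beam_search() signature (TokensPrompt / TextPrompt) in llm.py below.
dummy_prompts = [{"prompt_token_ids": [1, 2, 3, 4]}]  # placeholder ids

outputs = llm.beam_search(
    dummy_prompts,
    BeamSearchParams(beam_width=4, max_tokens=16, ignore_eos=True),
)
```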

vllm/entrypoints/llm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
parse_chat_messages,
2222
resolve_chat_template_content_format)
2323
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
24-
from vllm.inputs.parse import parse_and_batch_prompt
24+
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
2525
from vllm.logger import init_logger
2626
from vllm.lora.request import LoRARequest
2727
from vllm.model_executor.guided_decoding.guided_fields import (
@@ -457,7 +457,7 @@ def generate(
457457

458458
def beam_search(
459459
self,
460-
prompts: List[Union[str, List[int]]],
460+
prompts: List[Union[TokensPrompt, TextPrompt]],
461461
params: BeamSearchParams,
462462
) -> List[BeamSearchOutput]:
463463
"""
@@ -493,8 +493,10 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
493493
instances: List[BeamSearchInstance] = []
494494

495495
for prompt in prompts:
496-
prompt_tokens = prompt if isinstance(
497-
prompt, list) else tokenizer.encode(prompt)
496+
if is_token_prompt(prompt):
497+
prompt_tokens = prompt["prompt_token_ids"]
498+
else:
499+
prompt_tokens = tokenizer.encode(prompt["prompt"])
498500
instances.append(BeamSearchInstance(prompt_tokens))
499501

500502
for _ in range(max_tokens):
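As a usage note, a small sketch (again not part of the commit) of the two prompt forms the updated signature accepts: `is_token_prompt()` selects the branch that reads `prompt_token_ids` directly, while text prompts are encoded by the tokenizer first. The token ids and text below are placeholders.

```python
# Sketch only: the two prompt forms beam_search() now accepts, per the
# updated type hint List[Union[TokensPrompt, TextPrompt]].
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.sampling_params import BeamSearchParams

params = BeamSearchParams(beam_width=2, max_tokens=8)

# Token prompt: is_token_prompt() is true, so the ids are used as-is.
token_prompt = TokensPrompt(prompt_token_ids=[101, 2009, 2003])

# Text prompt: encoded via tokenizer.encode(prompt["prompt"]) first.
text_prompt = TextPrompt(prompt="Hello, world")

# With a constructed LLM instance (see the previous sketch):
# outputs = llm.beam_search([token_prompt, text_prompt], params)
```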
