 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.inputs import PromptType
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser


@@ -40,6 +41,20 @@ def main(args: argparse.Namespace):
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]

+    def llm_generate():
+        if not args.use_beam_search:
+            llm.generate(dummy_prompts,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+        else:
+            llm.beam_search(
+                dummy_prompts,
+                BeamSearchParams(
+                    beam_width=args.n,
+                    max_tokens=args.output_len,
+                    ignore_eos=True,
+                ))
+
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
             with torch.profiler.profile(
@@ -49,15 +64,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
                     ],
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
                         str(profile_dir))) as p:
-                llm.generate(dummy_prompts,
-                             sampling_params=sampling_params,
-                             use_tqdm=False)
+                llm_generate()
             print(p.key_averages().table(sort_by="self_cuda_time_total"))
         else:
             start_time = time.perf_counter()
-            llm.generate(dummy_prompts,
-                         sampling_params=sampling_params,
-                         use_tqdm=False)
+            llm_generate()
             end_time = time.perf_counter()
             latency = end_time - start_time
             return latency
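For reference, here is a minimal standalone sketch of the two call paths the new llm_generate() helper switches between. The model name and prompt below are placeholders, and the exact prompt types beam_search accepts have shifted slightly across vLLM versions, so treat this as illustrative rather than canonical:

from vllm import LLM, SamplingParams
from vllm.sampling_params import BeamSearchParams

# Placeholder model, chosen only to keep the sketch small.
llm = LLM(model="facebook/opt-125m")

# Sampling path, as in the unchanged branch of llm_generate().
llm.generate(["Hello, my name is"],
             sampling_params=SamplingParams(max_tokens=16),
             use_tqdm=False)

# Beam-search path, mirroring the new branch: beam_width and max_tokens
# play the roles of args.n and args.output_len in the benchmark.
llm.beam_search(
    [{"prompt": "Hello, my name is"}],
    BeamSearchParams(beam_width=4, max_tokens=16, ignore_eos=True))

In the benchmark itself, the beam-search path is presumably selected via the script's use_beam_search argument, with the beam width taken from --n and the output length from --output-len.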