@@ -34,7 +34,8 @@
 from typing import Any, Optional
 
 import numpy as np
-from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
+from backend_request_func import (ASYNC_REQUEST_FUNCS,
+                                  OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
                                   RequestFuncOutput)
 from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
@@ -260,6 +261,7 @@ async def benchmark(
     goodput_config_dict: dict[str, float],
     max_concurrency: Optional[int],
     lora_modules: Optional[Iterable[str]],
+    extra_body: Optional[dict],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -287,6 +289,7 @@ async def benchmark(
         logprobs=logprobs,
         multi_modal_content=test_mm_content,
         ignore_eos=ignore_eos,
+        extra_body=extra_body,
     )
 
     test_output = await request_func(request_func_input=test_input)
@@ -313,7 +316,8 @@ async def benchmark(
                                          output_len=test_output_len,
                                          logprobs=logprobs,
                                          multi_modal_content=test_mm_content,
-                                         ignore_eos=ignore_eos)
+                                         ignore_eos=ignore_eos,
+                                         extra_body=extra_body)
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
             print("Profiler started")
@@ -363,7 +367,8 @@ async def limited_request_func(request_func_input, pbar):
                                               output_len=output_len,
                                               logprobs=logprobs,
                                               multi_modal_content=mm_content,
-                                              ignore_eos=ignore_eos)
+                                              ignore_eos=ignore_eos,
+                                              extra_body=extra_body)
         tasks.append(
             asyncio.create_task(
                 limited_request_func(request_func_input=request_func_input,
@@ -652,6 +657,26 @@ def main(args: argparse.Namespace):
         raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
     goodput_config_dict = check_goodput_args(args)
 
+    # Collect the sampling parameters.
+    sampling_params = {
+        k: v
+        for k, v in {
+            "top_p": args.top_p,
+            "top_k": args.top_k,
+            "min_p": args.min_p,
+            "temperature": args.temperature
+        }.items() if v is not None
+    }
+
+    # Sampling parameters are only supported by openai-compatible backends.
+    if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
+        raise ValueError(
+            "Sampling parameters are only supported by openai-compatible "
+            "backends.")
+
+    if "temperature" not in sampling_params:
+        sampling_params["temperature"] = 0.0  # Default to greedy decoding.
+
     # Avoid GC processing "static" data - reduce pause times.
     gc.collect()
     gc.freeze()
@@ -678,6 +703,7 @@ def main(args: argparse.Namespace):
             goodput_config_dict=goodput_config_dict,
             max_concurrency=args.max_concurrency,
             lora_modules=args.lora_modules,
+            extra_body=sampling_params,
         ))
 
     # Save config and results to json
@@ -1000,6 +1026,33 @@ def main(args: argparse.Namespace):
         "from the sampled HF dataset.",
     )
 
+    sampling_group = parser.add_argument_group("sampling parameters")
+    sampling_group.add_argument(
+        "--top-p",
+        type=float,
+        default=None,
+        help="Top-p sampling parameter. Only has effect on openai-compatible "
+        "backends.")
+    sampling_group.add_argument(
+        "--top-k",
+        type=int,
+        default=None,
+        help="Top-k sampling parameter. Only has effect on openai-compatible "
+        "backends.")
+    sampling_group.add_argument(
+        "--min-p",
+        type=float,
+        default=None,
+        help="Min-p sampling parameter. Only has effect on openai-compatible "
+        "backends.")
+    sampling_group.add_argument(
+        "--temperature",
+        type=float,
+        default=None,
+        help="Temperature sampling parameter. Only has effect on "
+        "openai-compatible backends. If not specified, defaults to greedy "
+        "decoding (i.e. temperature==0.0).")
+
     parser.add_argument(
         '--tokenizer-mode',
         type=str,
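
Note: the new block in main() keeps only the sampling flags the user actually set, so unset flags never reach the server, and it falls back to greedy decoding when no temperature is given. A minimal standalone sketch of that behavior (the Namespace below stands in for argparse's parsed args; the values are illustrative):

    from argparse import Namespace

    # Stand-in for parsed CLI args; here only --top-p was "passed".
    args = Namespace(top_p=0.9, top_k=None, min_p=None, temperature=None)

    # Same filtering as the patch: drop parameters the user did not set.
    sampling_params = {
        k: v
        for k, v in {
            "top_p": args.top_p,
            "top_k": args.top_k,
            "min_p": args.min_p,
            "temperature": args.temperature
        }.items() if v is not None
    }

    # Same default as the patch: greedy decoding unless a temperature was given.
    if "temperature" not in sampling_params:
        sampling_params["temperature"] = 0.0

    print(sampling_params)  # {'top_p': 0.9, 'temperature': 0.0}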
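
The resulting dict then travels through RequestFuncInput.extra_body. A sketch of how an OpenAI-compatible request function might merge it into the JSON request body; build_payload is a hypothetical helper for illustration, not code from this commit:

    from typing import Optional

    def build_payload(model: str, prompt: str, max_tokens: int,
                      extra_body: Optional[dict]) -> dict:
        # Base completions-style payload; extra_body entries such as
        # temperature/top_p/top_k/min_p are merged in on top.
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
        }
        if extra_body:
            payload.update(extra_body)
        return payload

    print(build_payload("my-model", "Hello", 16,
                        {"temperature": 0.0, "top_p": 0.9}))
    # {'model': 'my-model', 'prompt': 'Hello', 'max_tokens': 16,
    #  'temperature': 0.0, 'top_p': 0.9}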