
Commit 057b6b1

benchmark_serving support --served-model-name param
Signed-off-by: zibai <[email protected]>
1 parent f8ef146 commit 057b6b1

File tree

2 files changed: +34 -17 lines

benchmarks/backend_request_func.py
benchmarks/benchmark_serving.py
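Why this is needed (implied by the help text added in benchmark_serving.py below, though not spelled out in the commit message): an OpenAI-compatible server may expose a model under an alias, for example one set with vLLM's own --served-model-name server flag, while the benchmark still needs the local --model path to load the tokenizer. This commit threads an optional model_name through the request functions so the "model" field sent in API payloads can differ from that path.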

benchmarks/backend_request_func.py (20 additions, 17 deletions)
@@ -22,6 +22,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    model_name: str = None
     best_of: int = 1
     logprobs: Optional[int] = None
     extra_body: Optional[dict] = None
@@ -43,8 +44,8 @@ class RequestFuncOutput:
 
 
 async def async_request_tgi(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
@@ -78,7 +79,7 @@ async def async_request_tgi(
                         continue
                     chunk_bytes = chunk_bytes.decode("utf-8")
 
-                    #NOTE: Sometimes TGI returns a ping response without
+                    # NOTE: Sometimes TGI returns a ping response without
                     # any data, we should skip it.
                     if chunk_bytes.startswith(":"):
                         continue
@@ -115,8 +116,8 @@ async def async_request_tgi(
 
 
 async def async_request_trt_llm(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
@@ -182,8 +183,8 @@ async def async_request_trt_llm(
 
 
 async def async_request_deepspeed_mii(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert request_func_input.best_of == 1
@@ -225,8 +226,8 @@ async def async_request_deepspeed_mii(
 
 
 async def async_request_openai_completions(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
@@ -235,7 +236,8 @@ async def async_request_openai_completions(
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+            if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
             "best_of": request_func_input.best_of,
@@ -315,8 +317,8 @@ async def async_request_openai_completions(
 
 
 async def async_request_openai_chat_completions(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
@@ -328,7 +330,8 @@ async def async_request_openai_chat_completions(
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+            if request_func_input.model_name else request_func_input.model,
             "messages": [
                 {
                     "role": "user",
@@ -417,10 +420,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 
 
 def get_tokenizer(
-    pretrained_model_name_or_path: str,
-    tokenizer_mode: str = "auto",
-    trust_remote_code: bool = False,
-    **kwargs,
+        pretrained_model_name_or_path: str,
+        tokenizer_mode: str = "auto",
+        trust_remote_code: bool = False,
+        **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     if pretrained_model_name_or_path is not None and not os.path.exists(
             pretrained_model_name_or_path):
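The substantive change in this file is the payload fallback in the two OpenAI-style request functions: the "model" field now carries model_name when it is set, otherwise model. A minimal runnable sketch of that behavior; the field names come from the diff, while the helper function and example values are hypothetical (note the diff's model_name: str = None would more idiomatically be typed Optional[str], as done here):

from dataclasses import dataclass
from typing import Optional


@dataclass
class RequestFuncInput:
    # Trimmed to the two fields relevant to the fallback.
    model: str  # local path or HF id, also used to load the tokenizer
    model_name: Optional[str] = None  # name the server knows the model by


def payload_model(req: RequestFuncInput) -> str:
    # Mirrors the diff: send the served name when given, else the model id.
    return req.model_name if req.model_name else req.model


assert payload_model(RequestFuncInput(model="/models/llama")) == "/models/llama"
assert payload_model(
    RequestFuncInput(model="/models/llama",
                     model_name="llama-alias")) == "llama-alias"

Incidentally, the trailing backslash after request_func_input.model_name in the diff is redundant: the expression sits inside a dict literal, where Python already allows implicit line continuation.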

benchmarks/benchmark_serving.py (14 additions, 0 deletions)
@@ -525,6 +525,7 @@ async def benchmark(
     api_url: str,
     base_url: str,
     model_id: str,
+    model_name: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     logprobs: Optional[int],
@@ -553,6 +554,7 @@ async def benchmark(
             "Multi-modal content is only supported on 'openai-chat' backend.")
     test_input = RequestFuncInput(
         model=model_id,
+        model_name=model_name,
         prompt=test_prompt,
         api_url=api_url,
         prompt_len=test_prompt_len,
@@ -573,6 +575,7 @@ async def benchmark(
     if profile:
         print("Starting profiler...")
         profile_input = RequestFuncInput(model=model_id,
+                                         model_name=model_name,
                                          prompt=test_prompt,
                                          api_url=base_url + "/start_profile",
                                          prompt_len=test_prompt_len,
@@ -616,6 +619,7 @@ async def limited_request_func(request_func_input, pbar):
     async for request in get_request(input_requests, request_rate, burstiness):
         prompt, prompt_len, output_len, mm_content = request
         request_func_input = RequestFuncInput(model=model_id,
+                                              model_name=model_name,
                                               prompt=prompt,
                                               api_url=api_url,
                                               prompt_len=prompt_len,
@@ -780,6 +784,7 @@ def main(args: argparse.Namespace):
 
     backend = args.backend
     model_id = args.model
+    model_name = args.served_model_name
     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
     tokenizer_mode = args.tokenizer_mode
 
@@ -877,6 +882,7 @@ def main(args: argparse.Namespace):
             api_url=api_url,
             base_url=base_url,
             model_id=model_id,
+            model_name=model_name,
             tokenizer=tokenizer,
             input_requests=input_requests,
             logprobs=args.logprobs,
@@ -1222,5 +1228,13 @@ def main(args: argparse.Namespace):
         'always use the slow tokenizer. \n* '
         '"mistral" will always use the `mistral_common` tokenizer.')
 
+    parser.add_argument(
+        "--served-model-name",
+        type=str,
+        default=None,
+        help="The model name used in the API. "
+        "If not specified, the model name will be the "
+        "same as the ``--model`` argument. ")
+
     args = parser.parse_args()
     main(args)
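With the flag wired through main(), a typical invocation would pass the served alias alongside the model path, e.g. python benchmarks/benchmark_serving.py --backend vllm --model /models/llama --served-model-name llama-alias (paths and names here are hypothetical). A minimal sketch of how the new argument behaves, mirroring the add_argument in the diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument(
    "--served-model-name",
    type=str,
    default=None,
    help="The model name used in the API. "
    "If not specified, defaults to the --model value.")

# Flag present: the API payload should carry the served alias.
args = parser.parse_args(
    ["--model", "/models/llama", "--served-model-name", "llama-alias"])
assert args.served_model_name == "llama-alias"

# Flag absent: served_model_name is None, so RequestFuncInput's
# "model_name if model_name else model" fallback sends the --model value.
args = parser.parse_args(["--model", "/models/llama"])
assert args.served_model_name is None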
