@@ -22,6 +22,7 @@ class RequestFuncInput:
22
22
prompt_len : int
23
23
output_len : int
24
24
model : str
25
+ model_name : str = None
25
26
best_of : int = 1
26
27
logprobs : Optional [int ] = None
27
28
extra_body : Optional [dict ] = None
@@ -43,8 +44,8 @@ class RequestFuncOutput:
43
44
44
45
45
46
async def async_request_tgi (
46
- request_func_input : RequestFuncInput ,
47
- pbar : Optional [tqdm ] = None ,
47
+ request_func_input : RequestFuncInput ,
48
+ pbar : Optional [tqdm ] = None ,
48
49
) -> RequestFuncOutput :
49
50
api_url = request_func_input .api_url
50
51
assert api_url .endswith ("generate_stream" )
@@ -78,7 +79,7 @@ async def async_request_tgi(
78
79
continue
79
80
chunk_bytes = chunk_bytes .decode ("utf-8" )
80
81
81
- #NOTE: Sometimes TGI returns a ping response without
82
+ # NOTE: Sometimes TGI returns a ping response without
82
83
# any data, we should skip it.
83
84
if chunk_bytes .startswith (":" ):
84
85
continue
@@ -115,8 +116,8 @@ async def async_request_tgi(
115
116
116
117
117
118
async def async_request_trt_llm (
118
- request_func_input : RequestFuncInput ,
119
- pbar : Optional [tqdm ] = None ,
119
+ request_func_input : RequestFuncInput ,
120
+ pbar : Optional [tqdm ] = None ,
120
121
) -> RequestFuncOutput :
121
122
api_url = request_func_input .api_url
122
123
assert api_url .endswith ("generate_stream" )
@@ -182,8 +183,8 @@ async def async_request_trt_llm(
182
183
183
184
184
185
async def async_request_deepspeed_mii (
185
- request_func_input : RequestFuncInput ,
186
- pbar : Optional [tqdm ] = None ,
186
+ request_func_input : RequestFuncInput ,
187
+ pbar : Optional [tqdm ] = None ,
187
188
) -> RequestFuncOutput :
188
189
async with aiohttp .ClientSession (timeout = AIOHTTP_TIMEOUT ) as session :
189
190
assert request_func_input .best_of == 1
@@ -225,8 +226,8 @@ async def async_request_deepspeed_mii(
225
226
226
227
227
228
async def async_request_openai_completions (
228
- request_func_input : RequestFuncInput ,
229
- pbar : Optional [tqdm ] = None ,
229
+ request_func_input : RequestFuncInput ,
230
+ pbar : Optional [tqdm ] = None ,
230
231
) -> RequestFuncOutput :
231
232
api_url = request_func_input .api_url
232
233
assert api_url .endswith (
@@ -235,7 +236,8 @@ async def async_request_openai_completions(
235
236
236
237
async with aiohttp .ClientSession (timeout = AIOHTTP_TIMEOUT ) as session :
237
238
payload = {
238
- "model" : request_func_input .model ,
239
+ "model" : request_func_input .model_name \
240
+ if request_func_input .model_name else request_func_input .model ,
239
241
"prompt" : request_func_input .prompt ,
240
242
"temperature" : 0.0 ,
241
243
"best_of" : request_func_input .best_of ,
@@ -315,8 +317,8 @@ async def async_request_openai_completions(
315
317
316
318
317
319
async def async_request_openai_chat_completions (
318
- request_func_input : RequestFuncInput ,
319
- pbar : Optional [tqdm ] = None ,
320
+ request_func_input : RequestFuncInput ,
321
+ pbar : Optional [tqdm ] = None ,
320
322
) -> RequestFuncOutput :
321
323
api_url = request_func_input .api_url
322
324
assert api_url .endswith (
@@ -328,7 +330,8 @@ async def async_request_openai_chat_completions(
328
330
if request_func_input .multi_modal_content :
329
331
content .append (request_func_input .multi_modal_content )
330
332
payload = {
331
- "model" : request_func_input .model ,
333
+ "model" : request_func_input .model_name \
334
+ if request_func_input .model_name else request_func_input .model ,
332
335
"messages" : [
333
336
{
334
337
"role" : "user" ,
@@ -417,10 +420,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
417
420
418
421
419
422
def get_tokenizer (
420
- pretrained_model_name_or_path : str ,
421
- tokenizer_mode : str = "auto" ,
422
- trust_remote_code : bool = False ,
423
- ** kwargs ,
423
+ pretrained_model_name_or_path : str ,
424
+ tokenizer_mode : str = "auto" ,
425
+ trust_remote_code : bool = False ,
426
+ ** kwargs ,
424
427
) -> Union [PreTrainedTokenizer , PreTrainedTokenizerFast ]:
425
428
if pretrained_model_name_or_path is not None and not os .path .exists (
426
429
pretrained_model_name_or_path ):
0 commit comments