@@ -226,14 +226,30 @@ def __init__(self, addr, port, embedded_kailite):
226
226
self .embedded_kailite = embedded_kailite
227
227
228
228
229
- async def generate_text (self , newprompt , genparams ):
229
+ async def generate_text (self , newprompt , genparams , basic_api_flag ):
230
230
loop = asyncio .get_event_loop ()
231
231
executor = ThreadPoolExecutor ()
232
232
233
233
def run_blocking ():
234
234
# Reset finished status before generating
235
235
handle .bind_set_stream_finished (False )
236
236
237
+ if basic_api_flag :
238
+ return generate (
239
+ prompt = newprompt ,
240
+ max_length = genparams .get ('max' , 50 ),
241
+ temperature = genparams .get ('temperature' , 0.8 ),
242
+ top_k = int (genparams .get ('top_k' , 120 )),
243
+ top_a = genparams .get ('top_a' , 0.0 ),
244
+ top_p = genparams .get ('top_p' , 0.85 ),
245
+ typical_p = genparams .get ('typical' , 1.0 ),
246
+ tfs = genparams .get ('tfs' , 1.0 ),
247
+ rep_pen = genparams .get ('rep_pen' , 1.1 ),
248
+ rep_pen_range = genparams .get ('rep_pen_range' , 128 ),
249
+ seed = genparams .get ('sampler_seed' , - 1 ),
250
+ stop_sequence = genparams .get ('stop_sequence' , [])
251
+ )
252
+
237
253
return generate (prompt = newprompt ,
238
254
max_context_length = genparams .get ('max_context_length' , maxctx ),
239
255
max_length = genparams .get ('max_length' , 50 ),
@@ -251,7 +267,9 @@ def run_blocking():
251
267
252
268
recvtxt = await loop .run_in_executor (executor , run_blocking )
253
269
254
- res = {"results" : [{"text" : recvtxt }]}
270
+ utfprint ("\n Output: " + recvtxt )
271
+
272
+ res = {"data" : {"seqs" :[recvtxt ]}} if basic_api_flag else {"results" : [{"text" : recvtxt }]}
255
273
256
274
try :
257
275
return res
@@ -279,20 +297,19 @@ async def handle_sse_stream(self, request):
279
297
await response .write_eof ()
280
298
await response .force_close ()
281
299
282
    async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
        """Run one generation request, optionally alongside an SSE stream.

        Args:
            request: the incoming aiohttp request object (forwarded to the
                SSE stream handler when streaming is enabled).
            genparams: dict of sampler/generation parameters parsed from the
                request body.
            newprompt: the prompt text to generate from.
            basic_api_flag: forwarded to generate_text; selects the basic
                (non-KAI) request/response format — presumably the legacy
                API shape, TODO confirm against generate_text.
            stream_flag: when True, also run the SSE streaming handler
                concurrently with generation.

        Returns:
            The result of the generation task, or None if any awaited task
            raised (the exception is printed, not re-raised).
        """
        tasks = []

        # When streaming, the SSE writer runs concurrently with generation;
        # note the coroutine is appended directly and awaited via gather below.
        if stream_flag:
            tasks.append(self.handle_sse_stream(request,))

        # Generation is wrapped in a Task (not a bare coroutine) so its
        # result can be read back after gather completes.
        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
        tasks.append(generate_task)

        try:
            await asyncio.gather(*tasks)
            # Result is returned for both streaming and non-streaming paths
            # (the pre-change code returned it only when not streaming).
            generate_result = generate_task.result()
            return generate_result
        except Exception as e:
            # Best-effort error reporting; falls through returning None.
            print(e)
@@ -344,7 +361,6 @@ async def handle_post(self, request):
344
361
kai_api_flag = False
345
362
kai_sse_stream_flag = False
346
363
path = request .path .rstrip ('/' )
347
- print (request )
348
364
349
365
if modelbusy :
350
366
return web .json_response (
@@ -358,7 +374,7 @@ async def handle_post(self, request):
358
374
if path .endswith (('/api/v1/generate' , '/api/latest/generate' )):
359
375
kai_api_flag = True
360
376
361
- if path .endswith ('/api/v1 /generate/stream' ):
377
+ if path .endswith ('/api/extra /generate/stream' ):
362
378
kai_api_flag = True
363
379
kai_sse_stream_flag = True
364
380
@@ -378,7 +394,7 @@ async def handle_post(self, request):
378
394
fullprompt = genparams .get ('text' , "" )
379
395
newprompt = fullprompt
380
396
381
- gen = await self .handle_request (request , genparams , newprompt , kai_sse_stream_flag )
397
+ gen = await self .handle_request (request , genparams , newprompt , basic_api_flag , kai_sse_stream_flag )
382
398
383
399
modelbusy = False
384
400
0 commit comments