Skip to content

Commit dee692a

Browse files
committed
compatibility with basic_api, change api path to /extra
1 parent b4e9e18 commit dee692a

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

koboldcpp.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,30 @@ def __init__(self, addr, port, embedded_kailite):
226226
self.embedded_kailite = embedded_kailite
227227

228228

229-
async def generate_text(self, newprompt, genparams):
229+
async def generate_text(self, newprompt, genparams, basic_api_flag):
230230
loop = asyncio.get_event_loop()
231231
executor = ThreadPoolExecutor()
232232

233233
def run_blocking():
234234
# Reset finished status before generating
235235
handle.bind_set_stream_finished(False)
236236

237+
if basic_api_flag:
238+
return generate(
239+
prompt=newprompt,
240+
max_length=genparams.get('max', 50),
241+
temperature=genparams.get('temperature', 0.8),
242+
top_k=int(genparams.get('top_k', 120)),
243+
top_a=genparams.get('top_a', 0.0),
244+
top_p=genparams.get('top_p', 0.85),
245+
typical_p=genparams.get('typical', 1.0),
246+
tfs=genparams.get('tfs', 1.0),
247+
rep_pen=genparams.get('rep_pen', 1.1),
248+
rep_pen_range=genparams.get('rep_pen_range', 128),
249+
seed=genparams.get('sampler_seed', -1),
250+
stop_sequence=genparams.get('stop_sequence', [])
251+
)
252+
237253
return generate(prompt=newprompt,
238254
max_context_length=genparams.get('max_context_length', maxctx),
239255
max_length=genparams.get('max_length', 50),
@@ -251,7 +267,9 @@ def run_blocking():
251267

252268
recvtxt = await loop.run_in_executor(executor, run_blocking)
253269

254-
res = {"results": [{"text": recvtxt}]}
270+
utfprint("\nOutput: " + recvtxt)
271+
272+
res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
255273

256274
try:
257275
return res
@@ -279,20 +297,19 @@ async def handle_sse_stream(self, request):
279297
await response.write_eof()
280298
await response.force_close()
281299

282-
async def handle_request(self, request, genparams, newprompt, stream_flag):
300+
async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
283301
tasks = []
284302

285303
if stream_flag:
286304
tasks.append(self.handle_sse_stream(request,))
287305

288-
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
306+
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
289307
tasks.append(generate_task)
290308

291309
try:
292310
await asyncio.gather(*tasks)
293-
if not stream_flag:
294-
generate_result = generate_task.result()
295-
return generate_result
311+
generate_result = generate_task.result()
312+
return generate_result
296313
except Exception as e:
297314
print(e)
298315

@@ -344,7 +361,6 @@ async def handle_post(self, request):
344361
kai_api_flag = False
345362
kai_sse_stream_flag = False
346363
path = request.path.rstrip('/')
347-
print(request)
348364

349365
if modelbusy:
350366
return web.json_response(
@@ -358,7 +374,7 @@ async def handle_post(self, request):
358374
if path.endswith(('/api/v1/generate', '/api/latest/generate')):
359375
kai_api_flag = True
360376

361-
if path.endswith('/api/v1/generate/stream'):
377+
if path.endswith('/api/extra/generate/stream'):
362378
kai_api_flag = True
363379
kai_sse_stream_flag = True
364380

@@ -378,7 +394,7 @@ async def handle_post(self, request):
378394
fullprompt = genparams.get('text', "")
379395
newprompt = fullprompt
380396

381-
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
397+
gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
382398

383399
modelbusy = False
384400

0 commit comments

Comments
 (0)