Skip to content

Commit b4e9e18

Browse files
committed
fix legacy streaming
1 parent 9a8da35 commit b4e9e18

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

koboldcpp.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ctypes
66
import os
77
import argparse
8-
import json, sys, time, asyncio
8+
import json, sys, time, asyncio, socket
99
from aiohttp import web
1010
from concurrent.futures import ThreadPoolExecutor
1111

@@ -255,8 +255,8 @@ def run_blocking():
255255

256256
try:
257257
return res
258-
except:
259-
print("Generate: Error while generating")
258+
except Exception as e:
259+
print(f"Generate: Error while generating {e}")
260260

261261

262262
async def send_sse_event(self, response, event, data):
@@ -273,7 +273,6 @@ async def handle_sse_stream(self, request):
273273
event_data = {"token": token}
274274
event_str = json.dumps(event_data)
275275
await self.send_sse_event(response, "message", event_str)
276-
print(event_str)
277276

278277
await asyncio.sleep(0)
279278

@@ -288,7 +287,6 @@ async def handle_request(self, request, genparams, newprompt, stream_flag):
288287

289288
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
290289
tasks.append(generate_task)
291-
#tasks.append(self.generate_text(newprompt, genparams))
292290

293291
try:
294292
await asyncio.gather(*tasks)
@@ -344,7 +342,7 @@ async def handle_post(self, request):
344342
body = await request.content.read()
345343
basic_api_flag = False
346344
kai_api_flag = False
347-
kai_sse_stream_flag = True
345+
kai_sse_stream_flag = False
348346
path = request.path.rstrip('/')
349347
print(request)
350348

@@ -382,10 +380,10 @@ async def handle_post(self, request):
382380

383381
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
384382

385-
if not kai_sse_stream_flag:
386-
return web.Response(body=gen)
387-
388383
modelbusy = False
384+
385+
if not kai_sse_stream_flag:
386+
return web.Response(body=json.dumps(gen).encode())
389387
return web.Response();
390388

391389
return web.Response(status=404)
@@ -398,14 +396,19 @@ async def handle_head(self):
398396

399397
async def start_server(self):
400398

399+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
400+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
401+
sock.bind((self.addr, self.port))
402+
sock.listen(5)
403+
401404
self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
402405
self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
403406
self.app.router.add_route('OPTIONS', '/', self.handle_options)
404407
self.app.router.add_route('HEAD', '/', self.handle_head)
405408

406409
runner = web.AppRunner(self.app)
407410
await runner.setup()
408-
site = web.TCPSite(runner, self.addr, self.port)
411+
site = web.SockSite(runner, sock)
409412
await site.start()
410413

411414
# Keep Alive
@@ -415,7 +418,11 @@ async def start_server(self):
415418
except KeyboardInterrupt:
416419
await runner.cleanup()
417420
await site.stop()
418-
await exit(1)
421+
await sys.exit(0)
422+
finally:
423+
await runner.cleanup()
424+
await site.stop()
425+
await sys.exit(0)
419426

420427
async def run_server(addr, port, embedded_kailite=None):
421428
handler = ServerRequestHandler(addr, port, embedded_kailite)

0 commit comments

Comments
 (0)