-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
188 lines (159 loc) · 7.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import json
import traceback
import asyncio
from fastapi import FastAPI, HTTPException, Request, WebSocket, BackgroundTasks
from langchain_core.messages import HumanMessage
from error_monitor.error_monitoring_service import ErrorMonitoringService
from starlette.websockets import WebSocketState, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
# FastAPI application; the routes, WebSocket endpoint and lifecycle hooks below register on it.
app = FastAPI()

# Runs the error-report pipeline scheduled by receive_error_report.
monitor_service = ErrorMonitoringService()

# Shared hand-off queue: wait_for_client_response and websocket_endpoint put items on it,
# process_queue consumes them.
# NOTE(review): on Python < 3.10 asyncio.Queue binds to the event loop that is current at
# import time; confirm the server runs 3.10+ (or create the queue at startup).
message_queue = asyncio.Queue()

# Holds references to tasks started in startup_event until they finish (see its done-callback).
background_tasks = set()

# WebSocket connections that send_message_to_websockets broadcasts to.
active_websockets = set()

# NOTE(review): wildcard origins combined with allow_credentials=True is very permissive;
# restrict allow_origins before exposing this service publicly.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.on_event("startup")
async def startup_event():
    """Start the process_queue consumer loop as a tracked background task."""
    print("Starting background task")
    worker = asyncio.create_task(process_queue())
    # Keep a reference while the task runs; the done-callback drops it once it finishes.
    background_tasks.add(worker)
    worker.add_done_callback(background_tasks.discard)
    # Yield to the event loop once so the new task gets a chance to start.
    await asyncio.sleep(0)
    running = len(background_tasks)
    print(f"Background tasks: {running}")
@app.on_event("shutdown")
async def shutdown_event():
    """Cancel every tracked background task and wait until all of them have finished."""
    pending = list(background_tasks)
    for pending_task in pending:
        pending_task.cancel()
    # return_exceptions=True gathers cancellation/failure outcomes as results instead of
    # raising, so shutdown always runs to completion.
    await asyncio.gather(*pending, return_exceptions=True)
async def report_to_user(state, timeout=300.0):
    """Broadcast the current workflow state to clients and collect their reply.

    Args:
        state: Workflow state mapping. Reads ``"initial_error_report"`` (sent as-is) or
            ``"recommendation"`` (its ``[0]["text"]`` is sent as a chat message).
        timeout: Seconds to wait for a client reply to a recommendation before
            defaulting to IGNORED. Defaults to 300 (five minutes).

    Returns:
        A dict with a ``"messages"`` list holding one HumanMessage:
        - "Error report received" after broadcasting an initial error report;
        - the feedback text (also returned under ``"user_feedback"``) for a feedback reply;
        - "ACCEPTED" for an accept reply;
        - "IGNORED" for an ignore/unknown/malformed reply, a timeout, or when the state
          contains nothing to report.
    """
    message = None
    if "initial_error_report" in state:
        message = {"type": "error_report", "content": state["initial_error_report"]}
    elif "recommendation" in state:
        print(f"report_to_user - Recommendation: {state['recommendation']}")
        message = {"type": "chat", "content": state["recommendation"][0]["text"]}

    if not message:
        # Nothing to show the user.
        return {"messages": [HumanMessage("IGNORED")]}

    # Send the message to all connected WebSockets.
    await send_message_to_websockets(message)

    if "initial_error_report" in state:
        # Error reports are informational; no reply is awaited.
        return {"messages": [HumanMessage("Error report received")]}

    # Wait for a response from any client, giving up after `timeout` seconds.
    try:
        response = await asyncio.wait_for(wait_for_client_response(), timeout=timeout)
    except asyncio.TimeoutError:
        print("Timeout waiting for client response, defaulting to IGNORE")
        return {"messages": [HumanMessage("IGNORED")]}

    response_type = response.get("type") if isinstance(response, dict) else None
    if response_type == "feedback":
        feedback = response.get("content", "")
        return {"messages": [HumanMessage(feedback)], "user_feedback": feedback}
    if response_type == "accept":
        return {"messages": [HumanMessage("ACCEPTED")]}
    if response_type != "ignore":
        print(f"Unknown response type: {response_type}")
    return {"messages": [HumanMessage("IGNORED")]}
async def send_message_to_websockets(message):
    """Broadcast ``message`` as JSON to every tracked WebSocket and prune dead ones.

    A socket is removed from ``active_websockets`` when its application state is no
    longer CONNECTED or when sending to it raises. Send failures are logged, not raised,
    so one broken client does not stop the broadcast to the others.
    """
    websockets_to_remove = set()
    # Iterate over a snapshot. Each send_json awaits, and websocket_endpoint may add or
    # remove connections meanwhile; iterating the live set would then raise
    # "RuntimeError: Set changed size during iteration".
    for websocket in list(active_websockets):
        try:
            if websocket.application_state == WebSocketState.CONNECTED:
                await websocket.send_json(message)
            else:
                websockets_to_remove.add(websocket)
        except Exception as e:
            # Deliberately best-effort: log and drop the failing connection.
            print(f"Error sending message to WebSocket: {e}")
            websockets_to_remove.add(websocket)
    active_websockets.difference_update(websockets_to_remove)
async def update_system_status(state):
    """Broadcast a status update for ``state["status"]`` and decide whether to continue.

    - ACCEPTED: announces the first correction plan's text.
    - IGNORED / CORRECTIONS_COMPLETE: forwards the text of the last message's first
      content part.
    - Any other status: sends nothing.

    Returns ``{"status": "CONTINUE"}`` when the status is ACCEPTED, otherwise
    ``{"status": "END"}``.
    """
    current = state.get("status", None)
    outgoing = None

    if current == 'ACCEPTED':
        plan_text = state['correction_plan'][0]['text']
        outgoing = {"type": "status_update",
                    "content": f"Implementing the following plan: {plan_text}"}
    elif current in ('IGNORED', 'CORRECTIONS_COMPLETE'):
        last_message = state['messages'][-1]
        print(f"update_system_status - Message: {last_message}")
        outgoing = {"type": "status_update", "content": last_message.content[0]['text']}

    if outgoing:
        print(f"Sending status update: {outgoing}")
        await send_message_to_websockets(outgoing)

    next_status = "CONTINUE" if current == 'ACCEPTED' else "END"
    return {"status": next_status}
async def wait_for_client_response():
    """Ask process_queue for the next client reply and wait until it arrives.

    Puts a ``("wait_response", future)`` tuple on ``message_queue``; process_queue is
    expected to resolve that future with a client message (presumably the dict that
    websocket_endpoint enqueued — confirm against process_queue).

    If the caller is cancelled (e.g. by ``asyncio.wait_for`` timing out), the awaited
    future is cancelled too, so process_queue can see it is done and skip it.
    """
    # loop.create_future() is the documented way to create a Future bound to the
    # running loop (preferred over calling asyncio.Future() directly).
    response_future = asyncio.get_running_loop().create_future()
    await message_queue.put(("wait_response", response_future))
    return await response_future
@app.post("/error-report")
async def receive_error_report(request: Request, background_tasks: BackgroundTasks):
    """Accept an error-report notification and process it in the background.

    Expects a JSON object body with an optional ``"location"`` key (defaults to "").
    Schedules ``monitor_service.process_error_report(location, report_to_user,
    update_system_status)`` on FastAPI's BackgroundTasks, which runs it after the
    response is sent.

    Note: the ``background_tasks`` parameter here is FastAPI's per-request
    BackgroundTasks, not the module-level set of asyncio tasks with the same name.

    Returns:
        ``{"message": "Error report received", "location": <location>}``.

    Raises:
        HTTPException(400): body is not valid JSON or is not a JSON object.
        HTTPException(500): any other failure while handling the request.
    """
    try:
        payload = await request.json()
        if not isinstance(payload, dict):
            # e.g. a JSON list or number: previously crashed on .get() and returned 500.
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        print(f"Received payload: {payload}")
        error_report_location = payload.get("location", "")
        print(f"Error report location: {error_report_location}")
        # Process the error report in the background.
        background_tasks.add_task(monitor_service.process_error_report, error_report_location,
                                  report_to_user,
                                  update_system_status)
        return {"message": "Error report received", "location": error_report_location}
    except HTTPException:
        raise
    except (json.JSONDecodeError, UnicodeDecodeError):
        raise HTTPException(status_code=400, detail="Invalid JSON") from None
    except Exception as e:
        error_report = traceback.format_exc()
        print(f"Error: {error_report}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a client WebSocket and forward its replies to ``message_queue``.

    Each incoming text frame must be a JSON object whose ``"type"`` is one of
    "feedback", "ignore" or "accept"; such messages are queued for process_queue.
    Anything else gets an ``{"type": "error", ...}`` reply and the connection stays
    open. The socket is tracked in ``active_websockets`` while connected.
    """
    await websocket.accept()
    active_websockets.add(websocket)
    print(f"WebSocket connection established. Active connections: {len(active_websockets)}")
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "content": "Invalid JSON format"})
                continue
            # Valid JSON that is not an object, or lacks "type", used to raise
            # TypeError/KeyError here and tear down the whole connection.
            message_type = message.get("type") if isinstance(message, dict) else None
            if message_type in ("feedback", "ignore", "accept"):
                # This is a response to a previous message.
                await message_queue.put(message)
            else:
                await websocket.send_json({"type": "error", "content": "Unknown message type"})
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    finally:
        # discard, not remove: send_message_to_websockets may already have pruned this
        # socket after a failed send, and remove() would then raise KeyError.
        active_websockets.discard(websocket)
        print(f"WebSocket connection closed. Active connections: {len(active_websockets)}")
async def process_queue():
    """Match client replies on ``message_queue`` to pending wait_for_client_response calls.

    ``message_queue`` carries two kinds of items:
      * ``("wait_response", future)`` tuples put by wait_for_client_response;
      * message dicts put by websocket_endpoint.

    Each reply dict resolves the oldest still-pending future (FIFO). Only dicts are
    treated as replies, so a second waiter's tuple can never be delivered as the
    first waiter's "response". A reply arriving while nobody is waiting is logged
    and dropped. Runs until cancelled (see shutdown_event).
    """
    from collections import deque

    print("Starting message processing loop")
    waiters = deque()  # futures still awaiting a client reply, oldest first
    while True:
        print("Waiting for message")
        item = await message_queue.get()
        try:
            if isinstance(item, tuple) and len(item) == 2 and item[0] == "wait_response":
                response_future = item[1]
                if response_future.done():
                    # e.g. the waiter timed out and its future was cancelled while queued.
                    print("Attempted to set result on a done future, skipping.")
                else:
                    waiters.append(response_future)
            elif isinstance(item, dict):
                # Drop waiters that gave up (cancelled on timeout) after being registered.
                while waiters and waiters[0].done():
                    waiters.popleft()
                if waiters:
                    waiters.popleft().set_result(item)
                else:
                    message_type = item.get("type")
                    if message_type in ["feedback", "ignore", "accept"]:
                        print(f"Processing message type: {message_type}")
                    else:
                        print(f"Unknown message type received: {message_type}")
            else:
                print("Received an item of unknown type")
        finally:
            # Exactly one task_done() per get(); the old code did two get()s per
            # task_done() on the wait path, so message_queue.join() could never finish.
            message_queue.task_done()