-
Notifications
You must be signed in to change notification settings - Fork 675
/
Copy pathmessaging.py
165 lines (145 loc) · 5.29 KB
/
messaging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from typing import Dict, Any, Optional, List
import asyncio
import queue
import logging
from uuid import uuid4
from anyio.streams.memory import MemoryObjectSendStream
from llama_index.callbacks.base import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.query_engine.sub_question_query_engine import SubQuestionAnswerPair
from llama_index.agent.openai_agent import StreamingAgentChatResponse
from pydantic import BaseModel
from app import schema
from app.schema import SubProcessMetadataKeysEnum, SubProcessMetadataMap
from app.models.db import MessageSubProcessSourceEnum
from app.chat.engine import get_chat_engine
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class StreamedMessage(BaseModel):
    """Assistant response text sent over the chat send stream."""

    # Response text. ``handle_chat_message`` sends the text accumulated so
    # far on every streamed token (not just the newest token), and sends a
    # fixed apology string when the final response is blank.
    content: str
class StreamedMessageSubProcess(BaseModel):
    """Progress notification for one sub-process of answering a chat message.

    Instances are sent over the chat send stream: one for each callback
    event start/end forwarded by ``ChatCallbackHandler`` and one emitted by
    ``handle_chat_message`` once the chat engine has been obtained.
    """

    # Which kind of sub-process this notification is about.
    source: MessageSubProcessSourceEnum
    # ``False`` for an event-start notification, ``True`` for an event-end one.
    has_ended: bool
    # Identifier of the event; a callback event's start and end notifications
    # are sent with the ``event_id`` the callback hook was given.
    event_id: str
    # Optional extra data about the event (e.g. a sub-question/answer pair).
    # The explicit ``None`` default keeps the field optional regardless of the
    # pydantic major version: callers in this module construct the model
    # without it, and an ``Optional[...]`` annotation with no default is only
    # implicitly optional under pydantic v1.
    metadata_map: Optional[SubProcessMetadataMap] = None
class ChatCallbackHandler(BaseCallbackHandler):
    """Callback handler that forwards callback events to a send stream.

    Every event start/end that reaches this handler is wrapped in a
    ``StreamedMessageSubProcess`` and sent on ``send_chan``. The callback
    hooks are synchronous, so the actual send is scheduled as an asyncio
    task; a running event loop is therefore required when a hook fires.
    """

    def __init__(
        self,
        send_chan: MemoryObjectSendStream,
    ):
        """Initialize the handler.

        Args:
            send_chan: Stream on which ``StreamedMessageSubProcess`` objects
                are sent, one per forwarded event start or end.
        """
        # CHUNKING and NODE_PARSING are ignored for both starts and ends.
        ignored_events = [CBEventType.CHUNKING, CBEventType.NODE_PARSING]
        super().__init__(ignored_events, ignored_events)
        self._send_chan = send_chan
        # Strong references to in-flight send tasks. The event loop only keeps
        # weak references to tasks, so a task nobody else references could be
        # garbage-collected before it finishes.
        self._pending_tasks = set()

    def _schedule_event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
        is_start_event: bool,
        **kwargs: Any,
    ) -> None:
        """Schedule ``async_on_event`` as a task and keep it alive until done."""
        task = asyncio.create_task(
            self.async_on_event(
                event_type, payload, event_id, is_start_event=is_start_event, **kwargs
            )
        )
        self._pending_tasks.add(task)
        # Drop the reference as soon as the task completes.
        task.add_done_callback(self._pending_tasks.discard)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Schedule a "started" notification for the event.

        Args:
            event_type: Type of the event that started.
            payload: Optional event payload.
            event_id: Identifier of the event.
            **kwargs: Forwarded unchanged to ``async_on_event``.

        Returns:
            The ``event_id`` that was passed in.

        Raises:
            RuntimeError: If there is no running event loop to schedule on.
        """
        self._schedule_event(
            event_type, payload, event_id, is_start_event=True, **kwargs
        )
        # The signature promises a ``str``; hand the event id back to the caller.
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Schedule an "ended" notification for the event.

        Args:
            event_type: Type of the event that completed.
            payload: Optional event payload.
            event_id: Identifier of the event.
            **kwargs: Forwarded unchanged to ``async_on_event``.

        Raises:
            RuntimeError: If there is no running event loop to schedule on.
        """
        self._schedule_event(
            event_type, payload, event_id, is_start_event=False, **kwargs
        )

    def get_metadata_from_event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        is_start_event: bool = False,
    ) -> SubProcessMetadataMap:
        """Build the metadata map for an event.

        Only ``SUB_QUESTION`` events whose payload carries
        ``EventPayload.SUB_QUESTION`` produce metadata; every other event
        (including one with no payload at all) yields an empty map.

        Args:
            event_type: Type of the event.
            payload: Optional event payload.
            is_start_event: Accepted for interface symmetry; not used here.

        Returns:
            A dict that is either empty or holds the serialized
            question/answer pair under the ``SUB_QUESTION`` metadata key.
        """
        metadata_map = {}
        if (
            event_type == CBEventType.SUB_QUESTION
            # ``payload`` defaults to None; a membership test on None would
            # raise TypeError.
            and payload is not None
            and EventPayload.SUB_QUESTION in payload
        ):
            sub_q: SubQuestionAnswerPair = payload[EventPayload.SUB_QUESTION]
            metadata_map[
                SubProcessMetadataKeysEnum.SUB_QUESTION.value
            ] = schema.QuestionAnswerPair.from_sub_question_answer_pair(sub_q).dict()
        return metadata_map

    async def async_on_event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        is_start_event: bool = False,
        **kwargs: Any,
    ) -> None:
        """Send a ``StreamedMessageSubProcess`` describing the event.

        Args:
            event_type: Type of the event; its ``name`` is used to look up the
                matching ``MessageSubProcessSourceEnum`` member.
            payload: Optional event payload, used to derive metadata.
            event_id: Identifier of the event.
            is_start_event: ``True`` for an event start, ``False`` for an end.
            **kwargs: Accepted and ignored.

        Raises:
            KeyError: If ``MessageSubProcessSourceEnum`` has no member named
                like ``event_type``.
        """
        metadata_map = self.get_metadata_from_event(
            event_type, payload=payload, is_start_event=is_start_event
        )
        # An empty map is sent as None rather than as an empty dict.
        metadata_map = metadata_map or None
        source = MessageSubProcessSourceEnum[event_type.name]
        # Events can arrive after the stream has been closed (the send runs in
        # a background task); drop them instead of sending on a closed stream.
        if self._send_chan._closed:
            logger.debug("Received event after send channel closed. Ignoring.")
            return
        await self._send_chan.send(
            StreamedMessageSubProcess(
                source=source,
                metadata_map=metadata_map,
                event_id=event_id,
                has_ended=not is_start_event,
            )
        )

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """No-op."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """No-op."""
async def handle_chat_message(
    conversation: schema.Conversation,
    user_message: schema.UserMessageCreate,
    send_chan: MemoryObjectSendStream,
) -> None:
    """Answer one user message, streaming progress and text over ``send_chan``.

    Sequence of messages sent on ``send_chan``:

    1. A ``StreamedMessageSubProcess`` (source ``CONSTRUCTED_QUERY_ENGINE``,
       already ended) once the chat engine has been obtained.
    2. One ``StreamedMessage`` per streamed token, each carrying the whole
       response accumulated so far.
    3. A ``StreamedMessage`` with a fixed apology if the final response is
       empty or whitespace-only.

    ``send_chan`` is entered as an async context manager for the duration of
    the call. Streaming stops early, without the fallback message, if the
    stream is found closed while tokens are still arriving.

    Args:
        conversation: Conversation passed to ``get_chat_engine``.
        user_message: Message whose ``content`` is appended to the prompt.
        send_chan: Stream on which all output is sent; also handed to the
            ``ChatCallbackHandler`` given to the chat engine.
    """
    async with send_chan:
        callback_handler = ChatCallbackHandler(send_chan)
        engine = await get_chat_engine(callback_handler, conversation)

        # Tell the receiver that engine construction has finished.
        engine_ready = StreamedMessageSubProcess(
            source=MessageSubProcessSourceEnum.CONSTRUCTED_QUERY_ENGINE,
            has_ended=True,
            event_id=str(uuid4()),
        )
        await send_chan.send(engine_ready)
        logger.debug("Engine received")

        # The literal's lines are deliberately flush-left: they are part of
        # the prompt text, which is stripped of surrounding whitespace.
        prompt = f"""
Remember - if I have asked a relevant financial question, use your tools.
{user_message.content}
""".strip()

        stream: StreamingAgentChatResponse = await engine.astream_chat(prompt)

        accumulated = ""
        async for token in stream.async_response_gen():
            accumulated += token
            if send_chan._closed:
                logger.debug(
                    "Received streamed token after send channel closed. Ignoring."
                )
                return
            # Each message carries the full text so far, not just the delta.
            await send_chan.send(StreamedMessage(content=accumulated))

        if accumulated.strip():
            return

        # Nothing but whitespace came back: send a fixed fallback answer.
        fallback = StreamedMessage(
            content="Sorry, I either wasn't able to understand your question or I don't have an answer for it."
        )
        await send_chan.send(fallback)