from typing import Any, Awaitable, List, Optional, Union, cast

from azure.search.documents.aio import SearchClient
from azure.search.documents.models import VectorQuery
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionMessageParam,
    ChatCompletionToolParam,
)

from approaches.approach import DataPoints, ExtraInfo, ThoughtStep
from approaches.chatapproach import ChatApproach
from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper


class ChatReadRetrieveReadApproach(ChatApproach):
    """
    A multi-step approach that first uses OpenAI to turn the user's question into a search query,
    then uses Azure AI Search to retrieve relevant documents, and then sends the conversation history,
    original user question, and search results to OpenAI to generate a response.
    """

    def __init__(
        self,
        *,
        search_client: SearchClient,
        auth_helper: AuthenticationHelper,
        openai_client: AsyncOpenAI,
        chatgpt_model: str,
        chatgpt_deployment: Optional[str],  # Not needed for non-Azure OpenAI
        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
        embedding_model: str,
        embedding_dimensions: int,
        embedding_field: str,
        sourcepage_field: str,
        content_field: str,
        query_language: str,
        query_speller: str,
        prompt_manager: PromptManager,
        reasoning_effort: Optional[str] = None,
    ):
        self.search_client = search_client
        self.openai_client = openai_client
        self.auth_helper = auth_helper
        self.chatgpt_model = chatgpt_model
        self.chatgpt_deployment = chatgpt_deployment
        self.embedding_deployment = embedding_deployment
        self.embedding_model = embedding_model
        self.embedding_dimensions = embedding_dimensions
        self.embedding_field = embedding_field
        self.sourcepage_field = sourcepage_field
        self.content_field = content_field
        self.query_language = query_language
        self.query_speller = query_speller
        self.prompt_manager = prompt_manager
        self.query_rewrite_prompt = self.prompt_manager.load_prompt("chat_query_rewrite.prompty")
        self.query_rewrite_tools = self.prompt_manager.load_tools("chat_query_rewrite_tools.json")
        self.answer_prompt = self.prompt_manager.load_prompt("chat_answer_question.prompty")
        self.reasoning_effort = reasoning_effort
        self.include_token_usage = True

    async def run_until_final_call(
        self,
        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        should_stream: bool = False,
    ) -> tuple[ExtraInfo, Union[Awaitable[ChatCompletion], Awaitable[AsyncStream[ChatCompletionChunk]]]]:
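        # Derive feature toggles from per-request overrides; note that an unset
        # retrieval_mode (None) matches both branches below, i.e. hybrid search.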
        use_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
        use_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
        use_semantic_ranker = True if overrides.get("semantic_ranker") else False
        use_semantic_captions = True if overrides.get("semantic_captions") else False
        use_query_rewriting = True if overrides.get("query_rewriting") else False
        top = overrides.get("top", 3)
        minimum_search_score = overrides.get("minimum_search_score", 0.0)
        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
        filter = self.build_filter(overrides, auth_claims)

        original_user_query = messages[-1]["content"]
        if not isinstance(original_user_query, str):
            raise ValueError("The most recent message content must be a string.")
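
        # Some reasoning models cannot stream; fail fast here rather than erroring
        # mid-response (support is looked up in GPT_REASONING_MODELS, inherited from
        # the base approach class).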
        reasoning_model_support = self.GPT_REASONING_MODELS.get(self.chatgpt_model)
        if reasoning_model_support and (not reasoning_model_support.streaming and should_stream):
            raise Exception(
                f"{self.chatgpt_model} does not support streaming. Please use a different model or disable streaming."
            )

        query_messages = self.prompt_manager.render_prompt(
            self.query_rewrite_prompt, {"user_query": original_user_query, "past_messages": messages[:-1]}
        )
        tools: List[ChatCompletionToolParam] = self.query_rewrite_tools

        # STEP 1: Generate an optimized keyword search query based on the chat history and the last question
        chat_completion = cast(
            ChatCompletion,
            await self.create_chat_completion(
                self.chatgpt_deployment,
                self.chatgpt_model,
                messages=query_messages,
                overrides=overrides,
                response_token_limit=self.get_response_token_limit(
                    self.chatgpt_model, 100
                ),  # Setting too low risks malformed JSON, setting too high may affect performance
                temperature=0.0,  # Minimize creativity for search query generation
                tools=tools,
                reasoning_effort="low",  # Minimize reasoning for search query generation
            ),
        )

        query_text = self.get_search_query(chat_completion, original_user_query)
        # STEP 2: Retrieve relevant documents from the search index with the GPT optimized query
        # If retrieval mode includes vectors, compute an embedding for the query
        vectors: list[VectorQuery] = []
        if use_vector_search:
            vectors.append(await self.compute_text_embedding(query_text))

        results = await self.search(
            top,
            query_text,
            filter,
            vectors,
            use_text_search,
            use_vector_search,
            use_semantic_ranker,
            use_semantic_captions,
            minimum_search_score,
            minimum_reranker_score,
            use_query_rewriting,
        )

        # STEP 3: Generate a contextual and content specific answer using the search results and chat history
        text_sources = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)
        messages = self.prompt_manager.render_prompt(
            self.answer_prompt,
            self.get_system_prompt_variables(overrides.get("prompt_template"))
            | {
                "include_follow_up_questions": bool(overrides.get("suggest_followup_questions")),
                "past_messages": messages[:-1],
                "user_query": original_user_query,
                "text_sources": text_sources,
            },
        )

        extra_info = ExtraInfo(
            DataPoints(text=text_sources),
            thoughts=[
                self.format_thought_step_for_chatcompletion(
                    title="Prompt to generate search query",
                    messages=query_messages,
                    overrides=overrides,
                    model=self.chatgpt_model,
                    deployment=self.chatgpt_deployment,
                    usage=chat_completion.usage,
                    reasoning_effort="low",
                ),
                ThoughtStep(
                    "Search using generated search query",
                    query_text,
                    {
                        "use_semantic_captions": use_semantic_captions,
                        "use_semantic_ranker": use_semantic_ranker,
                        "use_query_rewriting": use_query_rewriting,
                        "top": top,
                        "filter": filter,
                        "use_vector_search": use_vector_search,
                        "use_text_search": use_text_search,
                    },
                ),
                ThoughtStep(
                    "Search results",
                    [result.serialize_for_results() for result in results],
                ),
                self.format_thought_step_for_chatcompletion(
                    title="Prompt to generate answer",
                    messages=messages,
                    overrides=overrides,
                    model=self.chatgpt_model,
                    deployment=self.chatgpt_deployment,
                    usage=None,
                ),
            ],
        )

        chat_coroutine = cast(
            Union[Awaitable[ChatCompletion], Awaitable[AsyncStream[ChatCompletionChunk]]],
            self.create_chat_completion(
                self.chatgpt_deployment,
                self.chatgpt_model,
                messages,
                overrides,
                self.get_response_token_limit(self.chatgpt_model, 1024),
                should_stream,
            ),
        )
        return (extra_info, chat_coroutine)
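

# --- Illustrative usage sketch (editor's addition; not part of the original module) ---
# A minimal, hedged example of wiring the approach up for one non-streaming turn.
# The endpoint, index, API keys, model names, and index field names below are
# placeholders/assumptions rather than values taken from this file, and the
# AuthenticationHelper / PromptManager instances are assumed to be constructed by
# the surrounding application, as this repo does in its app setup code.
async def _example_usage(auth_helper: AuthenticationHelper, prompt_manager: PromptManager) -> None:
    from azure.core.credentials import AzureKeyCredential

    approach = ChatReadRetrieveReadApproach(
        search_client=SearchClient(
            endpoint="https://<search-service>.search.windows.net",  # placeholder endpoint
            index_name="gptkbindex",  # placeholder index name
            credential=AzureKeyCredential("<search-api-key>"),  # placeholder key
        ),
        auth_helper=auth_helper,
        openai_client=AsyncOpenAI(api_key="<openai-api-key>"),  # non-Azure OpenAI client
        chatgpt_model="gpt-4o-mini",  # assumed chat model
        chatgpt_deployment=None,  # None when not using Azure OpenAI deployments
        embedding_deployment=None,
        embedding_model="text-embedding-3-small",  # assumed embedding model
        embedding_dimensions=1536,
        embedding_field="embedding",  # assumed index field names
        sourcepage_field="sourcepage",
        content_field="content",
        query_language="en-us",
        query_speller="lexicon",
        prompt_manager=prompt_manager,
    )
    extra_info, chat_coroutine = await approach.run_until_final_call(
        messages=[{"role": "user", "content": "What is included in my benefits plan?"}],
        overrides={"retrieval_mode": "hybrid", "top": 3},
        auth_claims={},
        should_stream=False,
    )
    # With should_stream=False the awaitable resolves to a ChatCompletion, and
    # extra_info carries the retrieved sources and per-step "thoughts".
    completion = cast(ChatCompletion, await chat_coroutine)
    print(completion.choices[0].message.content)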