retrievethenread.py
from typing import Any, Optional, cast
from azure.search.documents.aio import SearchClient
from azure.search.documents.models import VectorQuery
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from approaches.approach import Approach, DataPoints, ExtraInfo, ThoughtStep
from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper

class RetrieveThenReadApproach(Approach):
    """
    Simple retrieve-then-read implementation, using the AI Search and OpenAI APIs directly. It first retrieves
    top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
    (answer) with that prompt. A usage sketch follows the class definition.
    """

    def __init__(
        self,
        *,
        search_client: SearchClient,
        auth_helper: AuthenticationHelper,
        openai_client: AsyncOpenAI,
        chatgpt_model: str,
        chatgpt_deployment: Optional[str],  # Not needed for non-Azure OpenAI
        embedding_model: str,
        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
        embedding_dimensions: int,
        embedding_field: str,
        sourcepage_field: str,
        content_field: str,
        query_language: str,
        query_speller: str,
        prompt_manager: PromptManager,
        reasoning_effort: Optional[str] = None,
    ):
        self.search_client = search_client
        self.openai_client = openai_client
        self.auth_helper = auth_helper
        self.chatgpt_model = chatgpt_model
        self.chatgpt_deployment = chatgpt_deployment
        self.embedding_model = embedding_model
        self.embedding_dimensions = embedding_dimensions
        self.embedding_deployment = embedding_deployment
        self.embedding_field = embedding_field
        self.sourcepage_field = sourcepage_field
        self.content_field = content_field
        self.query_language = query_language
        self.query_speller = query_speller
        self.prompt_manager = prompt_manager
        self.answer_prompt = self.prompt_manager.load_prompt("ask_answer_question.prompty")
        self.reasoning_effort = reasoning_effort
        self.include_token_usage = True

    async def run(
        self,
        messages: list[ChatCompletionMessageParam],
        session_state: Any = None,
        context: dict[str, Any] = {},
    ) -> dict[str, Any]:
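        # Only the latest message is used as the query; prior turns are
        # ignored in this single-shot approach.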
        q = messages[-1]["content"]
        if not isinstance(q, str):
            raise ValueError("The most recent message content must be a string.")
        overrides = context.get("overrides", {})
        auth_claims = context.get("auth_claims", {})
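        # retrieval_mode selects the search strategy: "text" (keyword only),
        # "vectors" (embedding only), or "hybrid"/unset (both).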
        use_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
        use_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
        use_semantic_ranker = bool(overrides.get("semantic_ranker"))
        use_query_rewriting = bool(overrides.get("query_rewriting"))
        use_semantic_captions = bool(overrides.get("semantic_captions"))
        top = overrides.get("top", 3)
        minimum_search_score = overrides.get("minimum_search_score", 0.0)
        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
        filter = self.build_filter(overrides, auth_claims)

        # If retrieval mode includes vectors, compute an embedding for the query
        vectors: list[VectorQuery] = []
        if use_vector_search:
            vectors.append(await self.compute_text_embedding(q))
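        # Retrieve the top candidate chunks; results below the configured
        # minimum scores are dropped by the shared search helper.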
        results = await self.search(
            top,
            q,
            filter,
            vectors,
            use_text_search,
            use_vector_search,
            use_semantic_ranker,
            use_semantic_captions,
            minimum_search_score,
            minimum_reranker_score,
            use_query_rewriting,
        )

        # Process results
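        # Each result is flattened to a citation-prefixed text snippet (or its
        # semantic caption) before being rendered into the answer prompt.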
        text_sources = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)
        messages = self.prompt_manager.render_prompt(
            self.answer_prompt,
            self.get_system_prompt_variables(overrides.get("prompt_template"))
            | {"user_query": q, "text_sources": text_sources},
        )
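        # Single completion call: the rendered prompt already embeds the
        # retrieved sources, so no tool or follow-up call is needed.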
        chat_completion = cast(
            ChatCompletion,
            await self.create_chat_completion(
                self.chatgpt_deployment,
                self.chatgpt_model,
                messages=messages,
                overrides=overrides,
                response_token_limit=self.get_response_token_limit(self.chatgpt_model, 1024),
            ),
        )
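        # Record each stage (query parameters, raw results, final prompt) so
        # callers can surface the intermediate steps alongside the answer.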
        extra_info = ExtraInfo(
            DataPoints(text=text_sources),
            thoughts=[
                ThoughtStep(
                    "Search using user query",
                    q,
                    {
                        "use_semantic_captions": use_semantic_captions,
                        "use_semantic_ranker": use_semantic_ranker,
                        "use_query_rewriting": use_query_rewriting,
                        "top": top,
                        "filter": filter,
                        "use_vector_search": use_vector_search,
                        "use_text_search": use_text_search,
                    },
                ),
                ThoughtStep(
                    "Search results",
                    [result.serialize_for_results() for result in results],
                ),
                self.format_thought_step_for_chatcompletion(
                    title="Prompt to generate answer",
                    messages=messages,
                    overrides=overrides,
                    model=self.chatgpt_model,
                    deployment=self.chatgpt_deployment,
                    usage=chat_completion.usage,
                ),
            ],
        )
        return {
            "message": {
                "content": chat_completion.choices[0].message.content,
                "role": chat_completion.choices[0].message.role,
            },
            "context": extra_info,
            "session_state": session_state,
        }
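
Below is a minimal usage sketch of how this approach might be wired up and invoked. It is illustrative only: the search endpoint, index name, key, and model names are placeholders, and the AuthenticationHelper and PromptManager instances are assumed to come from the application's own setup code.

import asyncio

from azure.core.credentials import AzureKeyCredential
from azure.search.documents.aio import SearchClient
from openai import AsyncOpenAI

from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper


async def ask_once(auth_helper: AuthenticationHelper, prompt_manager: PromptManager, question: str) -> str:
    # Placeholder endpoint, index, and key; substitute real values.
    search_client = SearchClient(
        endpoint="https://<search-service>.search.windows.net",
        index_name="<index-name>",
        credential=AzureKeyCredential("<search-api-key>"),
    )
    approach = RetrieveThenReadApproach(
        search_client=search_client,
        auth_helper=auth_helper,
        openai_client=AsyncOpenAI(),  # AsyncAzureOpenAI would be the Azure-hosted equivalent
        chatgpt_model="gpt-4o-mini",  # placeholder model name
        chatgpt_deployment=None,  # only needed for Azure OpenAI deployments
        embedding_model="text-embedding-3-small",  # placeholder embedding model
        embedding_deployment=None,
        embedding_dimensions=1536,
        embedding_field="embedding",
        sourcepage_field="sourcepage",
        content_field="content",
        query_language="en-us",
        query_speller="lexicon",
        prompt_manager=prompt_manager,
    )
    try:
        result = await approach.run(
            [{"role": "user", "content": question}],
            context={"overrides": {"retrieval_mode": "hybrid", "top": 3}},
        )
    finally:
        await search_client.close()
    return result["message"]["content"]


# Example invocation, given app-constructed auth_helper and prompt_manager:
# asyncio.run(ask_once(auth_helper, prompt_manager, "What is included in my plan?"))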