# retrievethenreadvision.py
from typing import Any, Awaitable, Callable, Optional

from azure.search.documents.aio import SearchClient
from azure.storage.blob.aio import ContainerClient
from openai import AsyncOpenAI
from openai.types.chat import (
    ChatCompletionMessageParam,
)

from approaches.approach import Approach, DataPoints, ExtraInfo, ThoughtStep
from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper
from core.imageshelper import fetch_image


class RetrieveThenReadVisionApproach(Approach):
    """
    Simple retrieve-then-read implementation, using the AI Search and OpenAI APIs directly. It first retrieves
    the top documents (including images) from search, then constructs a prompt with them, and then uses OpenAI
    to generate a completion (answer) with that prompt.
    """

    def __init__(
        self,
        *,
        search_client: SearchClient,
        blob_container_client: ContainerClient,
        openai_client: AsyncOpenAI,
        auth_helper: AuthenticationHelper,
        gpt4v_deployment: Optional[str],
        gpt4v_model: str,
        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
        embedding_model: str,
        embedding_dimensions: int,
        embedding_field: str,
        sourcepage_field: str,
        content_field: str,
        query_language: str,
        query_speller: str,
        vision_endpoint: str,
        vision_token_provider: Callable[[], Awaitable[str]],
        prompt_manager: PromptManager,
    ):
        self.search_client = search_client
        self.blob_container_client = blob_container_client
        self.openai_client = openai_client
        self.auth_helper = auth_helper
        self.embedding_model = embedding_model
        self.embedding_deployment = embedding_deployment
        self.embedding_dimensions = embedding_dimensions
        self.embedding_field = embedding_field
        self.sourcepage_field = sourcepage_field
        self.content_field = content_field
        self.gpt4v_deployment = gpt4v_deployment
        self.gpt4v_model = gpt4v_model
        self.query_language = query_language
        self.query_speller = query_speller
        self.vision_endpoint = vision_endpoint
        self.vision_token_provider = vision_token_provider
        self.prompt_manager = prompt_manager
        self.answer_prompt = self.prompt_manager.load_prompt("ask_answer_question_vision.prompty")
        # Currently disabled due to issues with rendering token usage in the UI
        self.include_token_usage = False

    async def run(
        self,
        messages: list[ChatCompletionMessageParam],
        session_state: Any = None,
        context: dict[str, Any] = {},
    ) -> dict[str, Any]:
        q = messages[-1]["content"]
        if not isinstance(q, str):
            raise ValueError("The most recent message content must be a string.")

        overrides = context.get("overrides", {})
        seed = overrides.get("seed", None)
        auth_claims = context.get("auth_claims", {})
        # When retrieval_mode is unset (None), default to hybrid: both text and vector search.
        use_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
        use_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
        use_semantic_ranker = bool(overrides.get("semantic_ranker"))
        use_query_rewriting = bool(overrides.get("query_rewriting"))
        use_semantic_captions = bool(overrides.get("semantic_captions"))
        top = overrides.get("top", 3)
        minimum_search_score = overrides.get("minimum_search_score", 0.0)
        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
        filter = self.build_filter(overrides, auth_claims)
        vector_fields = overrides.get("vector_fields", [self.embedding_field])
        # When gpt4v_input is unset (None), send both text and images to the vision model.
        send_text_to_gptvision = overrides.get("gpt4v_input") in ["textAndImages", "texts", None]
        send_images_to_gptvision = overrides.get("gpt4v_input") in ["textAndImages", "images", None]

        # If retrieval mode includes vectors, compute an embedding for the query
        vectors = []
        if use_vector_search:
            for field in vector_fields:
                vector = (
                    await self.compute_image_embedding(q)
                    if field.startswith("image")
                    else await self.compute_text_embedding(q)
                )
                vectors.append(vector)

        results = await self.search(
            top,
            q,
            filter,
            vectors,
            use_text_search,
            use_vector_search,
            use_semantic_ranker,
            use_semantic_captions,
            minimum_search_score,
            minimum_reranker_score,
            use_query_rewriting,
        )

        # Process results
        text_sources = []
        image_sources = []

        if send_text_to_gptvision:
            text_sources = self.get_sources_content(results, use_semantic_captions, use_image_citation=True)
        if send_images_to_gptvision:
            for result in results:
                url = await fetch_image(self.blob_container_client, result)
                if url:
                    image_sources.append(url)

        messages = self.prompt_manager.render_prompt(
            self.answer_prompt,
            self.get_system_prompt_variables(overrides.get("prompt_template"))
            | {"user_query": q, "text_sources": text_sources, "image_sources": image_sources},
        )

        chat_completion = await self.openai_client.chat.completions.create(
            model=self.gpt4v_deployment if self.gpt4v_deployment else self.gpt4v_model,
            messages=messages,
            temperature=overrides.get("temperature", 0.3),
            max_tokens=1024,
            n=1,
            seed=seed,
        )

        extra_info = ExtraInfo(
            DataPoints(text=text_sources, images=image_sources),
            [
                ThoughtStep(
                    "Search using user query",
                    q,
                    {
                        "use_semantic_captions": use_semantic_captions,
                        "use_semantic_ranker": use_semantic_ranker,
                        "use_query_rewriting": use_query_rewriting,
                        "top": top,
                        "filter": filter,
                        "vector_fields": vector_fields,
                        "use_vector_search": use_vector_search,
                        "use_text_search": use_text_search,
                    },
                ),
                ThoughtStep(
                    "Search results",
                    [result.serialize_for_results() for result in results],
                ),
                ThoughtStep(
                    "Prompt to generate answer",
                    messages,
                    (
                        {"model": self.gpt4v_model, "deployment": self.gpt4v_deployment}
                        if self.gpt4v_deployment
                        else {"model": self.gpt4v_model}
                    ),
                ),
            ],
        )

        return {
            "message": {
                "content": chat_completion.choices[0].message.content,
                "role": chat_completion.choices[0].message.role,
            },
            "context": extra_info,
            "session_state": session_state,
        }
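

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream module). It shows
# one plausible way to wire up RetrieveThenReadVisionApproach and call run().
# All endpoint URLs, index/container/deployment/field names, and the API key
# below are hypothetical placeholders; auth_helper and prompt_manager are taken
# as parameters because their constructors are repo-specific.
async def _example_usage(auth_helper: AuthenticationHelper, prompt_manager: PromptManager) -> None:
    from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider

    credential = DefaultAzureCredential()
    approach = RetrieveThenReadVisionApproach(
        search_client=SearchClient(
            endpoint="https://<search-service>.search.windows.net",  # placeholder
            index_name="<index-name>",  # placeholder
            credential=credential,
        ),
        blob_container_client=ContainerClient(
            account_url="https://<storage-account>.blob.core.windows.net",  # placeholder
            container_name="<content-container>",  # placeholder
            credential=credential,
        ),
        openai_client=AsyncOpenAI(api_key="<api-key>"),  # placeholder; or an Azure OpenAI client
        auth_helper=auth_helper,
        gpt4v_deployment="<gpt4v-deployment>",  # placeholder; None for non-Azure OpenAI
        gpt4v_model="gpt-4o",  # assumption: any vision-capable chat model
        embedding_deployment="<embedding-deployment>",  # placeholder
        embedding_model="text-embedding-3-large",
        embedding_dimensions=3072,
        embedding_field="<embedding-field>",  # placeholder: vector field name in the index
        sourcepage_field="sourcepage",  # assumption: field names used by this repo's index
        content_field="content",
        query_language="en-us",
        query_speller="lexicon",
        vision_endpoint="https://<vision-resource>.cognitiveservices.azure.com",  # placeholder
        vision_token_provider=get_bearer_token_provider(
            credential, "https://cognitiveservices.azure.com/.default"
        ),
        prompt_manager=prompt_manager,
    )
    response = await approach.run(
        [{"role": "user", "content": "What does the architecture diagram show?"}],
        context={"overrides": {"retrieval_mode": "hybrid", "top": 3}},
    )
    print(response["message"]["content"])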