# approach.py

import os
from abc import ABC
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
List,
Optional,
TypedDict,
Union,
cast,
)
from urllib.parse import urljoin

import aiohttp
from azure.search.documents.aio import SearchClient
from azure.search.documents.models import (
QueryCaptionResult,
QueryType,
VectorizedQuery,
VectorQuery,
)
from openai import AsyncOpenAI, AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionReasoningEffort,
ChatCompletionToolParam,
)

from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper


@dataclass
class Document:
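    """A document retrieved from the search index, with its content, embeddings, access control lists (oids/groups), and relevance scores."""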
id: Optional[str]
content: Optional[str]
embedding: Optional[List[float]]
image_embedding: Optional[List[float]]
category: Optional[str]
sourcepage: Optional[str]
sourcefile: Optional[str]
oids: Optional[List[str]]
groups: Optional[List[str]]
captions: List[QueryCaptionResult]
score: Optional[float] = None
reranker_score: Optional[float] = None

    def serialize_for_results(self) -> dict[str, Any]:
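        """Converts the document to a JSON-serializable dict, trimming embeddings to a short preview."""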
result_dict = {
"id": self.id,
"content": self.content,
# Should we rename to its actual field name in the index?
"embedding": Document.trim_embedding(self.embedding),
"imageEmbedding": Document.trim_embedding(self.image_embedding),
"category": self.category,
"sourcepage": self.sourcepage,
"sourcefile": self.sourcefile,
"oids": self.oids,
"groups": self.groups,
"captions": (
[
{
"additional_properties": caption.additional_properties,
"text": caption.text,
"highlights": caption.highlights,
}
for caption in self.captions
]
if self.captions
else []
),
"score": self.score,
"reranker_score": self.reranker_score,
}
return result_dict

    @classmethod
def trim_embedding(cls, embedding: Optional[List[float]]) -> Optional[str]:
"""Returns a trimmed list of floats from the vector embedding."""
if embedding:
if len(embedding) > 2:
                # Format the embedding list to show the first 2 items followed by the count of the remaining items.
return f"[{embedding[0]}, {embedding[1]} ...+{len(embedding) - 2} more]"
else:
return str(embedding)
return None


@dataclass
class ThoughtStep:
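    """A titled step in the approach's thought process, with optional details and properties (e.g. token usage)."""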
title: str
description: Optional[Any]
props: Optional[dict[str, Any]] = None

    def update_token_usage(self, usage: CompletionUsage) -> None:
if self.props:
self.props["token_usage"] = TokenUsageProps.from_completion_usage(usage)


@dataclass
class DataPoints:
text: Optional[List[str]] = None
images: Optional[List] = None


@dataclass
class ExtraInfo:
data_points: DataPoints
thoughts: Optional[List[ThoughtStep]] = None
followup_questions: Optional[List[Any]] = None


@dataclass
class TokenUsageProps:
prompt_tokens: int
completion_tokens: int
reasoning_tokens: Optional[int]
total_tokens: int

    @classmethod
def from_completion_usage(cls, usage: CompletionUsage) -> "TokenUsageProps":
return cls(
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
reasoning_tokens=(
usage.completion_tokens_details.reasoning_tokens if usage.completion_tokens_details else None
),
total_tokens=usage.total_tokens,
)


# GPT reasoning models don't support the same set of parameters as other models
# https://learn.microsoft.com/azure/ai-services/openai/how-to/reasoning
@dataclass
class GPTReasoningModelSupport:
streaming: bool


class Approach(ABC):
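    """Base class for RAG approaches: retrieves documents from Azure AI Search and builds chat completions with the OpenAI client."""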
    # Maps each GPT reasoning model to the features it supports
GPT_REASONING_MODELS = {
"o1": GPTReasoningModelSupport(streaming=False),
"o3-mini": GPTReasoningModelSupport(streaming=True),
}
    # Reasoning models get a higher response token limit, since reasoning consumes completion tokens
RESPONSE_DEFAULT_TOKEN_LIMIT = 1024
RESPONSE_REASONING_DEFAULT_TOKEN_LIMIT = 8192

    def __init__(
self,
search_client: SearchClient,
openai_client: AsyncOpenAI,
auth_helper: AuthenticationHelper,
query_language: Optional[str],
query_speller: Optional[str],
embedding_deployment: Optional[str], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
embedding_model: str,
embedding_dimensions: int,
embedding_field: str,
openai_host: str,
vision_endpoint: str,
vision_token_provider: Callable[[], Awaitable[str]],
prompt_manager: PromptManager,
reasoning_effort: Optional[str] = None,
):
self.search_client = search_client
self.openai_client = openai_client
self.auth_helper = auth_helper
self.query_language = query_language
self.query_speller = query_speller
self.embedding_deployment = embedding_deployment
self.embedding_model = embedding_model
self.embedding_dimensions = embedding_dimensions
self.embedding_field = embedding_field
self.openai_host = openai_host
self.vision_endpoint = vision_endpoint
self.vision_token_provider = vision_token_provider
self.prompt_manager = prompt_manager
self.reasoning_effort = reasoning_effort
self.include_token_usage = True

    def build_filter(self, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> Optional[str]:
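        """Builds an OData filter string from category include/exclude overrides and the user's security filters."""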
include_category = overrides.get("include_category")
exclude_category = overrides.get("exclude_category")
security_filter = self.auth_helper.build_security_filters(overrides, auth_claims)
filters = []
if include_category:
filters.append("category eq '{}'".format(include_category.replace("'", "''")))
if exclude_category:
filters.append("category ne '{}'".format(exclude_category.replace("'", "''")))
if security_filter:
filters.append(security_filter)
return None if len(filters) == 0 else " and ".join(filters)

    async def search(
self,
top: int,
query_text: Optional[str],
filter: Optional[str],
vectors: List[VectorQuery],
use_text_search: bool,
use_vector_search: bool,
use_semantic_ranker: bool,
use_semantic_captions: bool,
minimum_search_score: Optional[float] = None,
minimum_reranker_score: Optional[float] = None,
use_query_rewriting: Optional[bool] = None,
) -> List[Document]:
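        """Runs a search query combining text, vector, and/or semantic ranking as requested,
        then filters out results below the minimum search and reranker scores."""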
search_text = query_text if use_text_search else ""
search_vectors = vectors if use_vector_search else []
if use_semantic_ranker:
results = await self.search_client.search(
search_text=search_text,
filter=filter,
top=top,
query_caption="extractive|highlight-false" if use_semantic_captions else None,
query_rewrites="generative" if use_query_rewriting else None,
vector_queries=search_vectors,
query_type=QueryType.SEMANTIC,
query_language=self.query_language,
query_speller=self.query_speller,
semantic_configuration_name="default",
semantic_query=query_text,
)
else:
results = await self.search_client.search(
search_text=search_text,
filter=filter,
top=top,
vector_queries=search_vectors,
)
documents = []
async for page in results.by_page():
async for document in page:
documents.append(
Document(
id=document.get("id"),
content=document.get("content"),
embedding=document.get(self.embedding_field),
image_embedding=document.get("imageEmbedding"),
category=document.get("category"),
sourcepage=document.get("sourcepage"),
sourcefile=document.get("sourcefile"),
oids=document.get("oids"),
groups=document.get("groups"),
captions=cast(List[QueryCaptionResult], document.get("@search.captions")),
score=document.get("@search.score"),
reranker_score=document.get("@search.reranker_score"),
)
)
qualified_documents = [
doc
for doc in documents
if (
(doc.score or 0) >= (minimum_search_score or 0)
and (doc.reranker_score or 0) >= (minimum_reranker_score or 0)
)
]
return qualified_documents

    def get_sources_content(
self, results: List[Document], use_semantic_captions: bool, use_image_citation: bool
) -> list[str]:
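        """Renders each result as a '<citation>: <content>' string, using semantic captions instead of full content when enabled."""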
def nonewlines(s: str) -> str:
return s.replace("\n", " ").replace("\r", " ")

        if use_semantic_captions:
return [
(self.get_citation((doc.sourcepage or ""), use_image_citation))
+ ": "
+ nonewlines(" . ".join([cast(str, c.text) for c in (doc.captions or [])]))
for doc in results
]
else:
return [
(self.get_citation((doc.sourcepage or ""), use_image_citation)) + ": " + nonewlines(doc.content or "")
for doc in results
]

    def get_citation(self, sourcepage: str, use_image_citation: bool) -> str:
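        """Converts a .png source page (e.g. 'doc-42.png') into a PDF page citation ('doc.pdf#page=42'), unless image citations are in use."""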
if use_image_citation:
return sourcepage
else:
path, ext = os.path.splitext(sourcepage)
if ext.lower() == ".png":
page_idx = path.rfind("-")
page_number = int(path[page_idx + 1 :])
return f"{path[:page_idx]}.pdf#page={page_number}"
return sourcepage

    async def compute_text_embedding(self, q: str) -> VectorizedQuery:
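        """Embeds the query text with the OpenAI embeddings API and returns a VectorizedQuery for vector search."""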
SUPPORTED_DIMENSIONS_MODEL = {
"text-embedding-ada-002": False,
"text-embedding-3-small": True,
"text-embedding-3-large": True,
}

        class ExtraArgs(TypedDict, total=False):
dimensions: int

        dimensions_args: ExtraArgs = (
{"dimensions": self.embedding_dimensions} if SUPPORTED_DIMENSIONS_MODEL[self.embedding_model] else {}
)
embedding = await self.openai_client.embeddings.create(
# Azure OpenAI takes the deployment name as the model name
model=self.embedding_deployment if self.embedding_deployment else self.embedding_model,
input=q,
**dimensions_args,
)
query_vector = embedding.data[0].embedding
# TODO: use optimizations from rag time journey 3
return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields=self.embedding_field)

    async def compute_image_embedding(self, q: str) -> VectorizedQuery:
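        """Vectorizes the query text with the Azure AI Vision retrieval API and returns a VectorizedQuery against the imageEmbedding field."""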
endpoint = urljoin(self.vision_endpoint, "computervision/retrieval:vectorizeText")
headers = {"Content-Type": "application/json"}
params = {"api-version": "2023-02-01-preview", "modelVersion": "latest"}
data = {"text": q}
headers["Authorization"] = "Bearer " + await self.vision_token_provider()
async with aiohttp.ClientSession() as session:
async with session.post(
url=endpoint, params=params, headers=headers, json=data, raise_for_status=True
) as response:
json = await response.json()
image_query_vector = json["vector"]
return VectorizedQuery(vector=image_query_vector, k_nearest_neighbors=50, fields="imageEmbedding")

    def get_system_prompt_variables(self, override_prompt: Optional[str]) -> dict[str, str]:
# Allows client to replace the entire prompt, or to inject into the existing prompt using >>>
if override_prompt is None:
return {}
elif override_prompt.startswith(">>>"):
return {"injected_prompt": override_prompt[3:]}
else:
return {"override_prompt": override_prompt}

    def get_response_token_limit(self, model: str, default_limit: int) -> int:
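        """Returns the higher reasoning-model token limit when the model is a reasoning model; otherwise the supplied default."""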
if model in self.GPT_REASONING_MODELS:
return self.RESPONSE_REASONING_DEFAULT_TOKEN_LIMIT
return default_limit

    def create_chat_completion(
self,
chatgpt_deployment: Optional[str],
chatgpt_model: str,
messages: list[ChatCompletionMessageParam],
overrides: dict[str, Any],
response_token_limit: int,
should_stream: bool = False,
tools: Optional[List[ChatCompletionToolParam]] = None,
temperature: Optional[float] = None,
n: Optional[int] = None,
reasoning_effort: Optional[ChatCompletionReasoningEffort] = None,
) -> Union[Awaitable[ChatCompletion], Awaitable[AsyncStream[ChatCompletionChunk]]]:
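        """Prepares a chat completion call, using reasoning-model parameters (max_completion_tokens,
        reasoning_effort) for reasoning models, and temperature/max_tokens otherwise."""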
if chatgpt_model in self.GPT_REASONING_MODELS:
params: Dict[str, Any] = {
# max_tokens is not supported
"max_completion_tokens": response_token_limit
}
# Adjust parameters for reasoning models
supported_features = self.GPT_REASONING_MODELS[chatgpt_model]
if supported_features.streaming and should_stream:
params["stream"] = True
params["stream_options"] = {"include_usage": True}
params["reasoning_effort"] = reasoning_effort or overrides.get("reasoning_effort") or self.reasoning_effort
else:
# Include parameters that may not be supported for reasoning models
params = {
"max_tokens": response_token_limit,
"temperature": temperature or overrides.get("temperature", 0.3),
}
if should_stream:
params["stream"] = True
params["stream_options"] = {"include_usage": True}
params["tools"] = tools
# Azure OpenAI takes the deployment name as the model name
return self.openai_client.chat.completions.create(
model=chatgpt_deployment if chatgpt_deployment else chatgpt_model,
messages=messages,
seed=overrides.get("seed", None),
n=n or 1,
**params,
)

    def format_thought_step_for_chatcompletion(
self,
title: str,
messages: List[ChatCompletionMessageParam],
overrides: dict[str, Any],
model: str,
deployment: Optional[str],
usage: Optional[CompletionUsage] = None,
reasoning_effort: Optional[ChatCompletionReasoningEffort] = None,
) -> ThoughtStep:
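        """Creates a ThoughtStep for a chat completion call, recording model, deployment, reasoning effort, and token usage when available."""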
properties: Dict[str, Any] = {"model": model}
if deployment:
properties["deployment"] = deployment
# Only add reasoning_effort setting if the model supports it
if model in self.GPT_REASONING_MODELS:
properties["reasoning_effort"] = reasoning_effort or overrides.get(
"reasoning_effort", self.reasoning_effort
)
if usage:
properties["token_usage"] = TokenUsageProps.from_completion_usage(usage)
return ThoughtStep(title, messages, properties)

    async def run(
self,
messages: list[ChatCompletionMessageParam],
session_state: Any = None,
context: dict[str, Any] = {},
) -> dict[str, Any]:
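        """Runs the approach on the messages and returns the full response; subclasses must implement."""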
raise NotImplementedError

    async def run_stream(
self,
messages: list[ChatCompletionMessageParam],
session_state: Any = None,
context: dict[str, Any] = {},
) -> AsyncGenerator[dict[str, Any], None]:
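        """Runs the approach on the messages and yields response chunks as they become available; subclasses must implement."""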
raise NotImplementedError
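

# A minimal sketch (not part of this module) of how a subclass might combine these
# helpers; `SimpleAskApproach` and the context keys below are hypothetical:
#
#     class SimpleAskApproach(Approach):
#         async def run(self, messages, session_state=None, context={}):
#             overrides = context.get("overrides", {})
#             query = str(messages[-1]["content"])
#             vector = await self.compute_text_embedding(query)
#             docs = await self.search(
#                 top=3,
#                 query_text=query,
#                 filter=self.build_filter(overrides, context.get("auth_claims", {})),
#                 vectors=[vector],
#                 use_text_search=True,
#                 use_vector_search=True,
#                 use_semantic_ranker=False,
#                 use_semantic_captions=False,
#             )
#             sources = self.get_sources_content(docs, use_semantic_captions=False, use_image_citation=False)
#             # ...build a prompt from `sources` and call self.create_chat_completion(...)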