-
Notifications
You must be signed in to change notification settings - Fork 4.7k
/
Copy pathfilestrategy.py
170 lines (153 loc) · 7.24 KB
/
filestrategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import logging
from typing import List, Optional
from azure.core.credentials import AzureKeyCredential
from .blobmanager import BlobManager
from .embeddings import ImageEmbeddings, OpenAIEmbeddings
from .fileprocessor import FileProcessor
from .listfilestrategy import File, ListFileStrategy
from .mediadescriber import ContentUnderstandingDescriber
from .searchmanager import SearchManager, Section
from .strategy import DocumentAction, SearchInfo, Strategy
# Module-level logger for this module's ingestion log messages. It is looked up
# by the fixed name "scripts" (not __name__), so its level and handlers are
# whatever has been configured for the "scripts" logger.
logger = logging.getLogger("scripts")
async def parse_file(
    file: File,
    file_processors: dict[str, FileProcessor],
    category: Optional[str] = None,
    image_embeddings: Optional[ImageEmbeddings] = None,
) -> List[Section]:
    """Parse ``file`` into pages and split those pages into indexable sections.

    The processor is looked up by the file's lowercased extension. Files with
    no matching processor are skipped (logged at INFO) and yield no sections.

    Args:
        file: the file to ingest; its ``content`` is handed to the parser.
        file_processors: parser/splitter pairs keyed by lowercased file extension.
        category: optional category attached to every produced ``Section``.
        image_embeddings: only used to decide whether to emit a warning that
            text is chunked while images cover the entire page.

    Returns:
        One ``Section`` per chunk produced by the processor's splitter, or an
        empty list when no processor is registered for the extension.
    """
    extension = file.file_extension().lower()
    processor = file_processors.get(extension)
    if processor is None:
        logger.info("Skipping '%s', no parser found.", file.filename())
        return []

    logger.info("Ingesting '%s'", file.filename())
    # The parser is an async generator; drain it fully before splitting.
    parsed_pages = []
    async for parsed_page in processor.parser.parse(content=file.content):
        parsed_pages.append(parsed_page)

    logger.info("Splitting '%s' into sections", file.filename())
    if image_embeddings:
        # Text is chunked below, but (per this message) images stay whole-page.
        logger.warning("Each page will be split into smaller chunks of text, but images will be of the entire page.")

    return [
        Section(chunk, content=file, category=category)
        for chunk in processor.splitter.split_pages(parsed_pages)
    ]
class FileStrategy(Strategy):
    """
    Strategy for ingesting documents into a search service from files stored either locally or in a data lake storage account.

    What ``run`` does depends on ``document_action``:

    * ``DocumentAction.Add`` - parse every listed file, upload it to blob storage and index its sections.
    * ``DocumentAction.Remove`` - remove each listed path from blob storage and from the index.
    * ``DocumentAction.RemoveAll`` - remove all blobs and all indexed content.
    """

    def __init__(
        self,
        list_file_strategy: ListFileStrategy,
        blob_manager: BlobManager,
        search_info: SearchInfo,
        file_processors: dict[str, FileProcessor],
        document_action: DocumentAction = DocumentAction.Add,
        embeddings: Optional[OpenAIEmbeddings] = None,
        image_embeddings: Optional[ImageEmbeddings] = None,
        search_analyzer_name: Optional[str] = None,
        search_field_name_embedding: Optional[str] = None,
        search_field_name_image_embedding: Optional[str] = None,
        use_acls: bool = False,
        category: Optional[str] = None,
        use_content_understanding: bool = False,
        content_understanding_endpoint: Optional[str] = None,
    ):
        """
        Args:
            list_file_strategy: provides the files (``list()``) for Add and the paths (``list_paths()``) for Remove.
            blob_manager: uploads files on Add and removes blobs on Remove/RemoveAll.
            search_info: passed to the SearchManager; its ``credential`` is also used for Content Understanding.
            file_processors: parser/splitter pairs keyed by lowercased file extension (see ``parse_file``).
            document_action: which operation ``run`` performs.
            embeddings: text embedding client passed to the SearchManager.
            image_embeddings: when set, the SearchManager is built with ``search_images=True`` and
                embeddings are computed for the uploaded blobs' SAS URIs.
            search_analyzer_name: passed to the SearchManager.
            search_field_name_embedding: passed to the SearchManager as ``field_name_embedding``.
            search_field_name_image_embedding: passed to the SearchManager as ``field_name_image_embedding``.
            use_acls: passed to the SearchManager.
            category: category attached to every section parsed during ``run``.
            use_content_understanding: if True, ``setup`` also creates a Content Understanding analyzer.
            content_understanding_endpoint: required when ``use_content_understanding`` is True.
        """
        self.list_file_strategy = list_file_strategy
        self.blob_manager = blob_manager
        self.file_processors = file_processors
        self.document_action = document_action
        self.embeddings = embeddings
        self.image_embeddings = image_embeddings
        self.search_analyzer_name = search_analyzer_name
        self.search_field_name_embedding = search_field_name_embedding
        self.search_field_name_image_embedding = search_field_name_image_embedding
        self.search_info = search_info
        self.use_acls = use_acls
        self.category = category
        self.use_content_understanding = use_content_understanding
        self.content_understanding_endpoint = content_understanding_endpoint

    def setup_search_manager(self):
        """Build a SearchManager from this strategy's settings and store it on ``self.search_manager``."""
        self.search_manager = SearchManager(
            self.search_info,
            self.search_analyzer_name,
            self.use_acls,
            # NOTE(review): presumably the use_int_vectorization flag (cf. the keyword call in
            # UploadUserFileStrategy) - confirm against SearchManager's signature.
            False,
            self.embeddings,
            field_name_embedding=self.search_field_name_embedding,
            field_name_image_embedding=self.search_field_name_image_embedding,
            search_images=self.image_embeddings is not None,
        )

    async def setup(self):
        """Create the search index and, if enabled, the Content Understanding analyzer.

        Raises:
            ValueError: if Content Understanding is enabled but no endpoint was provided,
                or the search credential is an ``AzureKeyCredential``.
        """
        # Validate the Content Understanding configuration *before* creating the index,
        # so a misconfiguration fails fast instead of after remote side effects.
        cu_manager: Optional[ContentUnderstandingDescriber] = None
        if self.use_content_understanding:
            if self.content_understanding_endpoint is None:
                raise ValueError("Content Understanding is enabled but no endpoint was provided")
            if isinstance(self.search_info.credential, AzureKeyCredential):
                raise ValueError(
                    "AzureKeyCredential is not supported for Content Understanding, use keyless auth instead"
                )
            cu_manager = ContentUnderstandingDescriber(self.content_understanding_endpoint, self.search_info.credential)

        self.setup_search_manager()
        await self.search_manager.create_index()
        # The analyzer is still created only after the index, as before.
        if cu_manager is not None:
            await cu_manager.create_analyzer()

    async def run(self):
        """Apply ``self.document_action`` to blob storage and the search index."""
        self.setup_search_manager()
        if self.document_action == DocumentAction.Add:
            files = self.list_file_strategy.list()
            async for file in files:
                try:
                    sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
                    # Files that yield no sections (e.g. no parser for the extension) are
                    # neither uploaded nor indexed.
                    if sections:
                        blob_sas_uris = await self.blob_manager.upload_blob(file)
                        blob_image_embeddings: Optional[List[List[float]]] = None
                        if self.image_embeddings and blob_sas_uris:
                            blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
                        await self.search_manager.update_content(sections, blob_image_embeddings, url=file.url)
                finally:
                    # Close the file even when parsing, upload or indexing raised.
                    if file:
                        file.close()
        elif self.document_action == DocumentAction.Remove:
            paths = self.list_file_strategy.list_paths()
            async for path in paths:
                await self.blob_manager.remove_blob(path)
                await self.search_manager.remove_content(path)
        elif self.document_action == DocumentAction.RemoveAll:
            # No path argument: remove every blob and all indexed content.
            await self.blob_manager.remove_blob()
            await self.search_manager.remove_content()
class UploadUserFileStrategy:
    """
    Strategy for ingesting a file that has already been uploaded to an ADLS Gen2 storage account.

    Sections are indexed with ACLs enabled (``use_acls=True``) and without image search
    (``search_images=False``); this strategy never uploads blobs itself.
    """

    def __init__(
        self,
        search_info: SearchInfo,
        file_processors: dict[str, FileProcessor],
        embeddings: Optional[OpenAIEmbeddings] = None,
        image_embeddings: Optional[ImageEmbeddings] = None,
        search_field_name_embedding: Optional[str] = None,
        search_field_name_image_embedding: Optional[str] = None,
    ):
        """
        Args:
            search_info: passed to the SearchManager.
            file_processors: parser/splitter pairs keyed by lowercased file extension (see ``parse_file``).
            embeddings: text embedding client passed to the SearchManager.
            image_embeddings: stored only so ``add_file`` can warn that image embeddings are unsupported here.
            search_field_name_embedding: passed to the SearchManager as ``field_name_embedding``.
            search_field_name_image_embedding: passed to the SearchManager as ``field_name_image_embedding``.
        """
        self.file_processors = file_processors
        self.embeddings = embeddings
        self.image_embeddings = image_embeddings
        self.search_info = search_info
        self.search_manager = SearchManager(
            search_info=self.search_info,
            search_analyzer_name=None,
            use_acls=True,
            use_int_vectorization=False,
            embeddings=self.embeddings,
            field_name_embedding=search_field_name_embedding,
            field_name_image_embedding=search_field_name_image_embedding,
            search_images=False,
        )
        self.search_field_name_embedding = search_field_name_embedding
        self.search_field_name_image_embedding = search_field_name_image_embedding

    async def add_file(self, file: File):
        """Parse ``file`` and index its sections.

        No category and no image embeddings are used; a warning is logged if
        image embeddings were configured.
        """
        if self.image_embeddings:
            # Use the module logger (like the rest of this module) rather than the root-logger
            # helper, which would implicitly call logging.basicConfig() if the root has no handlers.
            logger.warning("Image embeddings are not currently supported for the user upload feature")
        sections = await parse_file(file, self.file_processors)
        if sections:
            await self.search_manager.update_content(sections, url=file.url)

    async def remove_file(self, filename: Optional[str], oid: str):
        """Remove indexed content for ``filename``.

        ``oid`` is passed through to ``SearchManager.remove_content`` (presumably the
        owning user's object id - confirm). A missing or empty ``filename`` is logged
        as a warning and nothing is removed.
        """
        if not filename:
            logger.warning("Filename is required to remove a file")
            return
        await self.search_manager.remove_content(filename, oid)