-
Notifications
You must be signed in to change notification settings - Fork 4.7k
/
Copy pathfilestrategy.py
154 lines (138 loc) · 6.41 KB
/
filestrategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import logging
from typing import List, Optional
from azure.core.credentials import AzureKeyCredential
from .blobmanager import BlobManager
from .embeddings import ImageEmbeddings, OpenAIEmbeddings
from .fileprocessor import FileProcessor
from .listfilestrategy import File, ListFileStrategy
from .mediadescriber import ContentUnderstandingDescriber
from .searchmanager import SearchManager, Section
from .strategy import DocumentAction, SearchInfo, Strategy
# Module-level logger for the ingestion scripts. It is named "scripts" (not
# __name__), so output from this module is configured through that one logger.
logger: logging.Logger = logging.getLogger(name="scripts")
async def parse_file(
    file: File,
    file_processors: dict[str, FileProcessor],
    category: Optional[str] = None,
    image_embeddings: Optional[ImageEmbeddings] = None,
) -> Optional[List[Section]]:
    """Parse one file and split it into ``Section`` objects for indexing.

    The processor is looked up by the file's lower-cased extension. Its parser
    turns the file content into pages, and its splitter turns those pages into
    chunks. Each chunk is wrapped in a ``Section``.

    Args:
        file: The file to ingest.
        file_processors: Mapping from lower-cased file extension to the
            ``FileProcessor`` (parser + splitter) that handles it.
        category: Optional category attached to every produced section.
        image_embeddings: Used here only to decide whether to warn that images
            are embedded per page while text is split into smaller chunks.

    Returns:
        The list of sections. Returns an empty list when no processor is
        registered for the extension, and ``None`` when the parser raised an
        error (the error is logged and the file is skipped).
    """
    key = file.file_extension().lower()
    processor = file_processors.get(key)
    if processor is None:
        logger.info("Skipping '%s', no parser found.", file.filename())
        return []
    logger.info("Ingesting '%s'", file.filename())
    try:
        pages = [page async for page in processor.parser.parse(content=file.content)]
    except Exception:
        # Best effort: one unparsable file is logged and skipped instead of
        # aborting the whole run. Catch Exception rather than using a bare
        # ``except:`` so that task cancellation (asyncio.CancelledError),
        # KeyboardInterrupt and SystemExit still propagate to the caller.
        logger.exception("There was a problem parsing the file %s, skipping...", file.filename())
        return None
    logger.info("Splitting '%s' into sections", file.filename())
    if image_embeddings:
        logger.warning("Each page will be split into smaller chunks of text, but images will be of the entire page.")
    return [
        Section(split_page, content=file, category=category)
        for split_page in processor.splitter.split_pages(pages)
    ]
class FileStrategy(Strategy):
    """
    Ingest documents into a search index from files stored locally or in a data lake
    storage account. The original files are also uploaded through a ``BlobManager``.
    """

    def __init__(
        self,
        list_file_strategy: ListFileStrategy,
        blob_manager: BlobManager,
        search_info: SearchInfo,
        file_processors: dict[str, FileProcessor],
        document_action: DocumentAction = DocumentAction.Add,
        embeddings: Optional[OpenAIEmbeddings] = None,
        image_embeddings: Optional[ImageEmbeddings] = None,
        search_analyzer_name: Optional[str] = None,
        use_acls: bool = False,
        category: Optional[str] = None,
        use_content_understanding: bool = False,
        content_understanding_endpoint: Optional[str] = None,
    ):
        """Store the ingestion configuration; no I/O happens here.

        Args:
            list_file_strategy: Source of files (``list()``) and paths (``list_paths()``).
            blob_manager: Uploads and removes the original files as blobs.
            search_info: Search service connection info; its ``credential`` is reused
                for Content Understanding.
            file_processors: Extension -> processor mapping passed to ``parse_file``.
            document_action: Which operation ``run`` performs (Add / Remove / RemoveAll).
            embeddings: Text embedding service handed to ``SearchManager``.
            image_embeddings: Optional image embedding service. When set, the index is
                created with image search enabled and uploaded blobs are embedded.
            search_analyzer_name: Analyzer name handed to ``SearchManager``.
            use_acls: Access-control flag handed to ``SearchManager``.
            category: Optional category attached to every ingested section.
            use_content_understanding: Whether ``setup`` also creates a Content
                Understanding analyzer.
            content_understanding_endpoint: Endpoint required when
                ``use_content_understanding`` is true.
        """
        # Sources and sinks.
        self.list_file_strategy = list_file_strategy
        self.blob_manager = blob_manager
        self.search_info = search_info
        # Processing configuration.
        self.file_processors = file_processors
        self.document_action = document_action
        self.embeddings = embeddings
        self.image_embeddings = image_embeddings
        self.search_analyzer_name = search_analyzer_name
        self.use_acls = use_acls
        self.category = category
        # Optional Content Understanding integration.
        self.use_content_understanding = use_content_understanding
        self.content_understanding_endpoint = content_understanding_endpoint

    async def setup(self):
        """Create the search index and, if enabled, the Content Understanding analyzer.

        Raises:
            ValueError: if Content Understanding is enabled but no endpoint was given,
                or if the search credential is an ``AzureKeyCredential`` (only keyless
                auth is accepted for Content Understanding).
        """
        index_manager = SearchManager(
            self.search_info,
            self.search_analyzer_name,
            self.use_acls,
            False,
            self.embeddings,
            search_images=self.image_embeddings is not None,
        )
        await index_manager.create_index()

        if not self.use_content_understanding:
            return

        endpoint = self.content_understanding_endpoint
        if endpoint is None:
            raise ValueError("Content Understanding is enabled but no endpoint was provided")

        credential = self.search_info.credential
        if isinstance(credential, AzureKeyCredential):
            raise ValueError(
                "AzureKeyCredential is not supported for Content Understanding, use keyless auth instead"
            )

        describer = ContentUnderstandingDescriber(endpoint, credential)
        await describer.create_analyzer()

    async def _ingest_file(self, file: File, search_manager: SearchManager) -> None:
        """Parse, upload and index a single file. Files yielding no sections are skipped."""
        sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
        if not sections:
            # Covers both "no processor for this extension" ([]) and "parse failed" (None).
            return
        blob_sas_uris = await self.blob_manager.upload_blob(file)
        image_vectors: Optional[List[List[float]]] = None
        if self.image_embeddings and blob_sas_uris:
            image_vectors = await self.image_embeddings.create_embeddings(blob_sas_uris)
        await search_manager.update_content(sections, image_vectors, url=file.url)

    async def run(self):
        """Apply ``self.document_action`` to blob storage and the search index.

        * ``Add``: for every listed file, parse it, upload it, optionally embed the
          uploaded blobs, and push its sections to the index. Each file is closed
          afterwards, even if processing raised.
        * ``Remove``: for every listed path, remove its blob and its index content.
        * ``RemoveAll``: call ``remove_blob()`` and ``remove_content()`` with no
          arguments.

        Any other action value does nothing.
        """
        search_manager = SearchManager(self.search_info, self.search_analyzer_name, self.use_acls, False, self.embeddings)
        action = self.document_action

        if action == DocumentAction.Add:
            async for file in self.list_file_strategy.list():
                try:
                    await self._ingest_file(file, search_manager)
                finally:
                    if file:
                        file.close()
        elif action == DocumentAction.Remove:
            async for path in self.list_file_strategy.list_paths():
                await self.blob_manager.remove_blob(path)
                await search_manager.remove_content(path)
        elif action == DocumentAction.RemoveAll:
            await self.blob_manager.remove_blob()
            await search_manager.remove_content()
class UploadUserFileStrategy:
    """
    Strategy for ingesting a file that has already been uploaded to an ADLS Gen2 storage account
    """

    def __init__(
        self,
        search_info: SearchInfo,
        file_processors: dict[str, FileProcessor],
        embeddings: Optional[OpenAIEmbeddings] = None,
        image_embeddings: Optional[ImageEmbeddings] = None,
    ):
        """Store the configuration and create the ``SearchManager`` used for every call.

        Args:
            search_info: Search service connection info.
            file_processors: Extension -> processor mapping passed to ``parse_file``.
            embeddings: Text embedding service handed to ``SearchManager``.
            image_embeddings: Accepted but not applied for user uploads; ``add_file``
                logs a warning when it is set.
        """
        self.file_processors = file_processors
        self.embeddings = embeddings
        self.image_embeddings = image_embeddings
        self.search_info = search_info
        # Positional arguments: no custom analyzer (None) and ACLs enabled (True),
        # following the SearchManager(search_info, analyzer_name, use_acls, ...) order
        # used elsewhere in this module.
        self.search_manager = SearchManager(self.search_info, None, True, False, self.embeddings)

    async def add_file(self, file: File):
        """Parse ``file`` and push its sections to the search index.

        Nothing is indexed when parsing yields no sections (unsupported extension
        or parse error). Image embeddings are never used for user uploads.
        """
        if self.image_embeddings:
            # Use the module's "scripts" logger rather than the root-level
            # logging.warning(), which also calls basicConfig() on the root logger
            # when that logger has no handlers.
            logger.warning("Image embeddings are not currently supported for the user upload feature")
        sections = await parse_file(file, self.file_processors)
        if sections:
            await self.search_manager.update_content(sections, url=file.url)

    async def remove_file(self, filename: Optional[str], oid: str):
        """Remove the index content for ``filename``.

        ``oid`` is forwarded to ``SearchManager.remove_content``. It is presumably
        the owning user's object id; confirm against SearchManager. An empty or
        missing filename is logged and ignored.
        """
        if not filename:
            logger.warning("Filename is required to remove a file")
            return
        await self.search_manager.remove_content(filename, oid)