post_processing.py
import logging
import os
import time

from neo4j import GraphDatabase
from langchain_neo4j import Neo4jGraph
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate

from src.shared.common_fn import load_embedding_model
from src.shared.constants import GRAPH_CLEANUP_PROMPT
from src.llm import get_llm
from src.graphDB_dataAccess import graphDBdataAccess
DROP_INDEX_QUERY = "DROP INDEX entities IF EXISTS;"
LABELS_QUERY = "CALL db.labels()"
FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX entities FOR (n{labels_str}) ON EACH [n.id, n.description];"
FILTER_LABELS = ["Chunk", "Document", "__Community__"]

HYBRID_SEARCH_INDEX_DROP_QUERY = "DROP INDEX keyword IF EXISTS;"
HYBRID_SEARCH_FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX keyword FOR (n:Chunk) ON EACH [n.text]"

COMMUNITY_INDEX_DROP_QUERY = "DROP INDEX community_keyword IF EXISTS;"
COMMUNITY_INDEX_FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX community_keyword FOR (n:`__Community__`) ON EACH [n.summary]"

CHUNK_VECTOR_INDEX_NAME = "vector"
CHUNK_VECTOR_EMBEDDING_DIMENSION = 384

DROP_CHUNK_VECTOR_INDEX_QUERY = f"DROP INDEX {CHUNK_VECTOR_INDEX_NAME} IF EXISTS;"
CREATE_CHUNK_VECTOR_INDEX_QUERY = """
CREATE VECTOR INDEX {index_name} IF NOT EXISTS FOR (c:Chunk) ON c.embedding
OPTIONS {{
  indexConfig: {{
    `vector.dimensions`: {embedding_dimension},
    `vector.similarity_function`: 'cosine'
  }}
}}
"""
def create_vector_index(driver, index_type, embedding_dimension=None):
    """Drop and recreate the chunk embedding vector index."""
    drop_query = ""
    query = ""

    if index_type == CHUNK_VECTOR_INDEX_NAME:
        drop_query = DROP_CHUNK_VECTOR_INDEX_QUERY
        query = CREATE_CHUNK_VECTOR_INDEX_QUERY.format(
            index_name=CHUNK_VECTOR_INDEX_NAME,
            embedding_dimension=embedding_dimension if embedding_dimension else CHUNK_VECTOR_EMBEDDING_DIMENSION
        )
    else:
        logging.error(f"Invalid index type provided: {index_type}")
        return

    try:
        logging.info("Starting the process to create vector index.")
        with driver.session() as session:
            try:
                start_step = time.time()
                session.run(drop_query)
                logging.info(f"Dropped existing index (if any) in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to drop index: {e}")
                return

            try:
                start_step = time.time()
                session.run(query)
                logging.info(f"Created vector index in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to create vector index: {e}")
                return
    except Exception as e:
        logging.error("An error occurred while creating the vector index.", exc_info=True)
        logging.error(f"Error details: {str(e)}")
def create_fulltext(driver, index_type):
    """Drop and recreate a full-text index for entity, hybrid (chunk), or community search."""
    start_time = time.time()
    try:
        with driver.session() as session:
            try:
                start_step = time.time()
                if index_type == "entities":
                    drop_query = DROP_INDEX_QUERY
                elif index_type == "hybrid":
                    drop_query = HYBRID_SEARCH_INDEX_DROP_QUERY
                else:
                    drop_query = COMMUNITY_INDEX_DROP_QUERY
                session.run(drop_query)
                logging.info(f"Dropped existing index (if any) in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to drop index: {e}")
                return

            try:
                if index_type == "entities":
                    start_step = time.time()
                    result = session.run(LABELS_QUERY)
                    labels = [record["label"] for record in result]

                    # Exclude infrastructure labels from the entity full-text index.
                    for label in FILTER_LABELS:
                        if label in labels:
                            labels.remove(label)

                    if labels:
                        labels_str = ":" + "|".join([f"`{label}`" for label in labels])
                        logging.info(f"Fetched labels in {time.time() - start_step:.2f} seconds.")
                    else:
                        logging.info("Full-text index not created because no entity labels were found.")
                        return
            except Exception as e:
                logging.error(f"Failed to fetch labels: {e}")
                return

            try:
                start_step = time.time()
                if index_type == "entities":
                    fulltext_query = FULL_TEXT_QUERY.format(labels_str=labels_str)
                elif index_type == "hybrid":
                    fulltext_query = HYBRID_SEARCH_FULL_TEXT_QUERY
                else:
                    fulltext_query = COMMUNITY_INDEX_FULL_TEXT_QUERY
                session.run(fulltext_query)
                logging.info(f"Created full-text index in {time.time() - start_step:.2f} seconds.")
            except Exception as e:
                logging.error(f"Failed to create full-text index: {e}")
                return
    except Exception as e:
        logging.error(f"An error occurred during the session: {e}")
    finally:
        logging.info(f"Process completed in {time.time() - start_time:.2f} seconds.")
def create_vector_fulltext_indexes(uri, username, password, database):
    """Create the full-text (entities, hybrid) and chunk vector indexes on the target database."""
    types = ["entities", "hybrid"]
    embedding_model = os.getenv('EMBEDDING_MODEL')
    embeddings, dimension = load_embedding_model(embedding_model)
    if not dimension:
        dimension = CHUNK_VECTOR_EMBEDDING_DIMENSION
    logging.info("Starting the process of creating full-text indexes.")

    try:
        driver = GraphDatabase.driver(uri, auth=(username, password), database=database)
        driver.verify_connectivity()
        logging.info("Database connectivity verified.")
    except Exception as e:
        logging.error(f"Error connecting to the database: {e}")
        return

    for index_type in types:
        try:
            logging.info(f"Creating a full-text index for type '{index_type}'.")
            create_fulltext(driver, index_type)
            logging.info(f"Full-text index for type '{index_type}' created successfully.")
        except Exception as e:
            logging.error(f"Failed to create full-text index for type '{index_type}': {e}")

    try:
        logging.info(f"Creating a vector index for type '{CHUNK_VECTOR_INDEX_NAME}'.")
        create_vector_index(driver, CHUNK_VECTOR_INDEX_NAME, dimension)
        logging.info("Vector index for chunk created successfully.")
    except Exception as e:
        logging.error(f"Failed to create vector index for '{CHUNK_VECTOR_INDEX_NAME}': {e}")

    try:
        driver.close()
        logging.info("Driver closed successfully.")
    except Exception as e:
        logging.error(f"Error closing the driver: {e}")

    logging.info("Full-text and vector index creation process completed.")
def create_entity_embedding(graph: Neo4jGraph):
    """Embed every entity node that does not yet have an embedding, in batches of 1000."""
    rows = fetch_entities_for_embedding(graph)
    for i in range(0, len(rows), 1000):
        update_embeddings(rows[i:i + 1000], graph)


def fetch_entities_for_embedding(graph):
    """Return the element id and embedding text of every entity node missing an embedding."""
    query = """
        MATCH (e)
        WHERE NOT (e:Chunk OR e:Document OR e:`__Community__`) AND e.embedding IS NULL AND e.id IS NOT NULL
        RETURN elementId(e) AS elementId, e.id + " " + coalesce(e.description, "") AS text
    """
    result = graph.query(query)
    return [{"elementId": record["elementId"], "text": record["text"]} for record in result]


def update_embeddings(rows, graph):
    """Compute embeddings for the given rows and write them back as node vector properties."""
    embedding_model = os.getenv('EMBEDDING_MODEL')
    embeddings, dimension = load_embedding_model(embedding_model)
    logging.info("Updating embeddings for entities.")
    for row in rows:
        row['embedding'] = embeddings.embed_query(row['text'])
    # db.create.setNodeVectorProperty stores each list as a typed vector property.
    query = """
        UNWIND $rows AS row
        MATCH (e) WHERE elementId(e) = row.elementId
        CALL db.create.setNodeVectorProperty(e, "embedding", row.embedding)
    """
    return graph.query(query, params={'rows': rows})
def graph_schema_consolidation(graph):
    """Use an LLM to merge near-duplicate node labels and relationship types into a consolidated schema."""
    graphDb_data_Access = graphDBdataAccess(graph)
    node_labels, relation_labels = graphDb_data_Access.get_nodelabels_relationships()

    parser = JsonOutputParser()
    prompt = ChatPromptTemplate(
        messages=[("system", GRAPH_CLEANUP_PROMPT), ("human", "{input}")],
        partial_variables={"format_instructions": parser.get_format_instructions()}
    )
    graph_cleanup_model = os.getenv("GRAPH_CLEANUP_MODEL", 'openai_gpt_4o')
    llm, _ = get_llm(graph_cleanup_model)
    chain = prompt | llm | parser

    nodes_relations_input = {'nodes': node_labels, 'relationships': relation_labels}
    mappings = chain.invoke({'input': nodes_relations_input})

    # The LLM returns {new_label: [old_labels]}; invert it into {old_label: new_label},
    # skipping identity mappings.
    node_mapping = {old: new for new, old_list in mappings['nodes'].items() for old in old_list if new != old}
    relation_mapping = {old: new for new, old_list in mappings['relationships'].items() for old in old_list if new != old}

    logging.info(f"Node labels: {len(node_labels)} total; merging {len(node_mapping)} labels into {len(set(node_mapping.values()))} consolidated labels.")
    logging.info(f"Relationship types: {len(relation_labels)} total; merging {len(relation_mapping)} types into {len(set(relation_mapping.values()))} consolidated types.")

    if node_mapping:
        for old_label, new_label in node_mapping.items():
            query = f"""
                MATCH (n:`{old_label}`)
                SET n:`{new_label}`
                REMOVE n:`{old_label}`
            """
            graph.query(query)

    # Relationship types cannot be renamed in place, so recreate each relationship
    # under the new type (copying its properties) and delete the old one.
    for old_label, new_label in relation_mapping.items():
        query = f"""
            MATCH (n)-[r:`{old_label}`]->(m)
            CREATE (n)-[r2:`{new_label}`]->(m)
            SET r2 += properties(r)
            DELETE r
        """
        graph.query(query)
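

# Usage sketch (illustrative, not part of the original module): one possible way to
# run these post-processing steps end to end. The NEO4J_* environment variable names
# and default values below are assumptions for this example only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")   # assumed env var
    username = os.getenv("NEO4J_USERNAME", "neo4j")         # assumed env var
    password = os.getenv("NEO4J_PASSWORD", "password")      # assumed env var
    database = os.getenv("NEO4J_DATABASE", "neo4j")         # assumed env var

    # Recreate the full-text and chunk vector indexes.
    create_vector_fulltext_indexes(uri, username, password, database)

    # Embed entities and consolidate the schema over the same database.
    graph = Neo4jGraph(url=uri, username=username, password=password, database=database)
    create_entity_embedding(graph)
    graph_schema_consolidation(graph)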