@@ -147,7 +147,6 @@ def get_sources_and_chunks(sources_used, docs):
147
147
result = {
148
148
'sources' : sources_used ,
149
149
'chunkdetails' : chunkdetails_list ,
150
- "entities" : list ()
151
150
}
152
151
return result
153
152
@@ -182,16 +181,19 @@ def format_documents(documents, model):
182
181
sorted_documents = sorted (documents , key = lambda doc : doc .state .get ("query_similarity_score" , 0 ), reverse = True )
183
182
sorted_documents = sorted_documents [:prompt_token_cutoff ]
184
183
185
- formatted_docs = []
184
+ formatted_docs = list ()
186
185
sources = set ()
187
- lc_entities = {'entities' :list ()}
186
+ entities = dict ()
187
+ global_communities = list ()
188
+
188
189
189
190
for doc in sorted_documents :
190
191
try :
191
192
source = doc .metadata .get ('source' , "unknown" )
192
193
sources .add (source )
193
194
194
- lc_entities = doc .metadata if 'entities' in doc .metadata .keys () else lc_entities
195
+ entities = doc .metadata ['entities' ] if 'entities' in doc .metadata .keys () else entities
196
+ global_communities = doc .metadata ["communitydetails" ] if 'communitydetails' in doc .metadata .keys () else global_communities
195
197
196
198
formatted_doc = (
197
199
"Document start\n "
@@ -204,13 +206,13 @@ def format_documents(documents, model):
204
206
except Exception as e :
205
207
logging .error (f"Error formatting document: { e } " )
206
208
207
- return "\n \n " .join (formatted_docs ), sources ,lc_entities
209
+ return "\n \n " .join (formatted_docs ), sources ,entities , global_communities
208
210
209
211
def process_documents (docs , question , messages , llm , model ,chat_mode_settings ):
210
212
start_time = time .time ()
211
213
212
214
try :
213
- formatted_docs , sources ,lc_entities = format_documents (docs , model )
215
+ formatted_docs , sources , entitydetails , communities = format_documents (docs , model )
214
216
215
217
rag_chain = get_rag_chain (llm = llm )
216
218
@@ -219,12 +221,25 @@ def process_documents(docs, question, messages, llm, model,chat_mode_settings):
219
221
"context" : formatted_docs ,
220
222
"input" : question
221
223
})
222
- if chat_mode_settings ["mode" ] == "entity search+vector" :
223
- result = {'sources' : list (),
224
- 'chunkdetails' : list ()}
225
- result .update (lc_entities )
224
+
225
+ result = {'sources' : list (), 'nodedetails' : dict (), 'entities' : dict ()}
226
+ node_details = {"chunkdetails" :list (),"entitydetails" :list (),"communitydetails" :list ()}
227
+ entities = {'entityids' :list (),"relationshipids" :list ()}
228
+
229
+ if chat_mode_settings ["mode" ] == CHAT_ENTITY_VECTOR_MODE :
230
+ node_details ["entitydetails" ] = entitydetails
231
+
232
+ elif chat_mode_settings ["mode" ] == CHAT_GLOBAL_VECTOR_FULLTEXT_MODE :
233
+ node_details ["communitydetails" ] = communities
226
234
else :
227
- result = get_sources_and_chunks (sources , docs )
235
+ sources_and_chunks = get_sources_and_chunks (sources , docs )
236
+ result ['sources' ] = sources_and_chunks ['sources' ]
237
+ node_details ["chunkdetails" ] = sources_and_chunks ["chunkdetails" ]
238
+ entities .update (entitydetails )
239
+
240
+ result ["nodedetails" ] = node_details
241
+ result ["entities" ] = entities
242
+
228
243
content = ai_response .content
229
244
total_tokens = get_total_tokens (ai_response , llm )
230
245
@@ -295,10 +310,13 @@ def create_document_retriever_chain(llm, retriever):
295
310
296
311
def initialize_neo4j_vector (graph , chat_mode_settings ):
297
312
try :
298
- mode = chat_mode_settings .get ('mode' , 'undefined' )
299
313
retrieval_query = chat_mode_settings .get ("retrieval_query" )
300
314
index_name = chat_mode_settings .get ("index_name" )
301
315
keyword_index = chat_mode_settings .get ("keyword_index" , "" )
316
+ node_label = chat_mode_settings .get ("node_label" )
317
+ embedding_node_property = chat_mode_settings .get ("embedding_node_property" )
318
+ text_node_properties = chat_mode_settings .get ("text_node_properties" )
319
+
302
320
303
321
if not retrieval_query or not index_name :
304
322
raise ValueError ("Required settings 'retrieval_query' or 'index_name' are missing." )
@@ -310,28 +328,21 @@ def initialize_neo4j_vector(graph, chat_mode_settings):
310
328
retrieval_query = retrieval_query ,
311
329
graph = graph ,
312
330
search_type = "hybrid" ,
313
- node_label = "Chunk" ,
314
- embedding_node_property = "embedding" ,
315
- text_node_properties = [ "text" ] ,
331
+ node_label = node_label ,
332
+ embedding_node_property = embedding_node_property ,
333
+ text_node_properties = text_node_properties ,
316
334
keyword_index_name = keyword_index
317
335
)
318
336
logging .info (f"Successfully retrieved Neo4jVector Fulltext index '{ index_name } ' and keyword index '{ keyword_index } '" )
319
- elif mode == "entity search+vector" :
320
- neo_db = Neo4jVector .from_existing_index (
321
- embedding = EMBEDDING_FUNCTION ,
322
- index_name = index_name ,
323
- retrieval_query = retrieval_query ,
324
- graph = graph
325
- )
326
337
else :
327
338
neo_db = Neo4jVector .from_existing_graph (
328
339
embedding = EMBEDDING_FUNCTION ,
329
340
index_name = index_name ,
330
341
retrieval_query = retrieval_query ,
331
342
graph = graph ,
332
- node_label = "Chunk" ,
333
- embedding_node_property = "embedding" ,
334
- text_node_properties = [ "text" ]
343
+ node_label = node_label ,
344
+ embedding_node_property = embedding_node_property ,
345
+ text_node_properties = text_node_properties
335
346
)
336
347
logging .info (f"Successfully retrieved Neo4jVector index '{ index_name } '" )
337
348
except Exception as e :
@@ -359,12 +370,12 @@ def create_retriever(neo_db, document_names, chat_mode_settings,search_k, score_
359
370
logging .info (f"Successfully created retriever with search_k={ search_k } , score_threshold={ score_threshold } " )
360
371
return retriever
361
372
362
- def get_neo4j_retriever (graph , document_names ,chat_mode_settings , search_k = CHAT_SEARCH_KWARG_K , score_threshold = CHAT_SEARCH_KWARG_SCORE_THRESHOLD ):
373
+ def get_neo4j_retriever (graph , document_names ,chat_mode_settings , score_threshold = CHAT_SEARCH_KWARG_SCORE_THRESHOLD ):
363
374
try :
364
-
375
+
365
376
neo_db = initialize_neo4j_vector (graph , chat_mode_settings )
366
377
document_names = list (map (str .strip , json .loads (document_names )))
367
- search_k = LOCAL_COMMUNITY_TOP_K if chat_mode_settings ["mode" ] == "entity search+vector" else CHAT_SEARCH_KWARG_K
378
+ search_k = chat_mode_settings ["top_k" ]
368
379
retriever = create_retriever (neo_db , document_names ,chat_mode_settings , search_k , score_threshold )
369
380
return retriever
370
381
except Exception as e :
@@ -397,12 +408,13 @@ def process_chat_response(messages, history, question, model, graph, document_na
397
408
try :
398
409
llm , doc_retriever , model_version = setup_chat (model , graph , document_names , chat_mode_settings )
399
410
400
- docs = retrieve_documents (doc_retriever , messages )
411
+ docs = retrieve_documents (doc_retriever , messages )
412
+
401
413
if docs :
402
414
content , result , total_tokens = process_documents (docs , question , messages , llm , model , chat_mode_settings )
403
415
else :
404
416
content = "I couldn't find any relevant documents to answer your question."
405
- result = {"sources" : [] , "chunkdetails " : [] , "entities" : [] }
417
+ result = {"sources" : list () , "nodedetails " : list () , "entities" : list () }
406
418
total_tokens = 0
407
419
408
420
ai_response = AIMessage (content = content )
@@ -412,18 +424,18 @@ def process_chat_response(messages, history, question, model, graph, document_na
412
424
summarization_thread .start ()
413
425
logging .info ("Summarization thread started." )
414
426
# summarize_and_log(history, messages, llm)
415
-
427
+
416
428
return {
417
429
"session_id" : "" ,
418
430
"message" : content ,
419
431
"info" : {
420
432
"sources" : result ["sources" ],
421
433
"model" : model_version ,
422
- "chunkdetails " : result ["chunkdetails " ],
434
+ "nodedetails " : result ["nodedetails " ],
423
435
"total_tokens" : total_tokens ,
424
436
"response_time" : 0 ,
425
437
"mode" : chat_mode_settings ["mode" ],
426
- "entities" : result ["entities" ]
438
+ "entities" : result ["entities" ],
427
439
},
428
440
"user" : "chatbot"
429
441
}
@@ -435,12 +447,12 @@ def process_chat_response(messages, history, question, model, graph, document_na
435
447
"message" : "Something went wrong" ,
436
448
"info" : {
437
449
"sources" : [],
438
- "chunkdetails " : [],
450
+ "nodedetails " : [],
439
451
"total_tokens" : 0 ,
440
452
"response_time" : 0 ,
441
453
"error" : f"{ type (e ).__name__ } : { str (e )} " ,
442
454
"mode" : chat_mode_settings ["mode" ],
443
- "entities" : []
455
+ "entities" : [],
444
456
},
445
457
"user" : "chatbot"
446
458
}
@@ -593,7 +605,7 @@ def create_neo4j_chat_message_history(graph, session_id, write_access=True):
593
605
raise
594
606
595
607
def get_chat_mode_settings (mode ,settings_map = CHAT_MODE_CONFIG_MAP ):
596
- default_settings = settings_map ["default" ]
608
+ default_settings = settings_map [CHAT_DEFAULT_MODE ]
597
609
try :
598
610
chat_mode_settings = settings_map .get (mode , default_settings )
599
611
chat_mode_settings ["mode" ] = mode
@@ -615,7 +627,7 @@ def QA_RAG(graph,model, question, document_names, session_id, mode, write_access
615
627
user_question = HumanMessage (content = question )
616
628
messages .append (user_question )
617
629
618
- if mode == "graph" :
630
+ if mode == CHAT_GRAPH_MODE :
619
631
result = process_graph_response (model , graph , question , messages , history )
620
632
else :
621
633
chat_mode_settings = get_chat_mode_settings (mode = mode )
0 commit comments