-
Notifications
You must be signed in to change notification settings - Fork 561
/
Copy pathcommunities.py
512 lines (425 loc) · 19.6 KB
/
communities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
import logging
from graphdatascience import GraphDataScience
from src.llm import get_llm
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from concurrent.futures import ThreadPoolExecutor, as_completed
import os
from src.shared.common_fn import get_value_from_env_or_secret_manager, load_embedding_model
# ---------------------------------------------------------------------------
# Community-detection configuration
# ---------------------------------------------------------------------------
COMMUNITY_PROJECTION_NAME = "communities"
# Node label filter for the projection: everything except Chunk / Document /
# __Community__ nodes (GDS label-expression syntax).
NODE_PROJECTION = "!Chunk&!Document&!__Community__"
NODE_PROJECTION_ENTITY = "__Entity__"
MAX_WORKERS = 10
MAX_COMMUNITY_LEVELS = 3
MIN_COMMUNITY_SIZE = 1
COMMUNITY_CREATION_DEFAULT_MODEL = "openai_gpt_4o"

# Projects an undirected, weighted in-memory graph over entity nodes; the
# relationship weight is the number of parallel relationships between a pair.
CREATE_COMMUNITY_GRAPH_PROJECTION = """
MATCH (source:{node_projection})-[]->(target:{node_projection})
WITH source, target, count(*) as weight
WITH gds.graph.project(
'{project_name}',
source,
target,
{{
relationshipProperties: {{ weight: weight }}
}},
{{undirectedRelationshipTypes: ['*']}}
) AS g
RETURN
g.graphName AS graph_name, g.nodeCount AS nodes, g.relationshipCount AS rels
"""
CREATE_COMMUNITY_CONSTRAINT = "CREATE CONSTRAINT IF NOT EXISTS FOR (c:__Community__) REQUIRE c.id IS UNIQUE;"
# Materialises __Community__ nodes from the intermediate community ids that
# Leiden wrote onto each entity's `communities` list property, and links each
# level to its parent level via PARENT_COMMUNITY.
CREATE_COMMUNITY_LEVELS = """
MATCH (e:`__Entity__`)
WHERE e.communities is NOT NULL
UNWIND range(0, size(e.communities) - 1 , 1) AS index
CALL {
WITH e, index
WITH e, index
WHERE index = 0
MERGE (c:`__Community__` {id: toString(index) + '-' + toString(e.communities[index])})
ON CREATE SET c.level = index
MERGE (e)-[:IN_COMMUNITY]->(c)
RETURN count(*) AS count_0
}
CALL {
WITH e, index
WITH e, index
WHERE index > 0
MERGE (current:`__Community__` {id: toString(index) + '-' + toString(e.communities[index])})
ON CREATE SET current.level = index
MERGE (previous:`__Community__` {id: toString(index - 1) + '-' + toString(e.communities[index - 1])})
ON CREATE SET previous.level = index - 1
MERGE (previous)-[:PARENT_COMMUNITY]->(current)
RETURN count(*) AS count_1
}
RETURN count(*)
"""
# Rank = number of distinct source documents reachable from the community.
# NOTE: fixed `[:HAS_ENTITY]` — the previous bare `[HAS_ENTITY]` only bound a
# relationship VARIABLE named HAS_ENTITY and matched any relationship type;
# the colon form filters by type, consistent with CREATE_COMMUNITY_WEIGHTS.
CREATE_COMMUNITY_RANKS = """
MATCH (c:__Community__)<-[:IN_COMMUNITY*]-(:!Chunk&!Document&!__Community__)<-[:HAS_ENTITY]-(:Chunk)<-[]-(d:Document)
WITH c, count(distinct d) AS rank
SET c.community_rank = rank;
"""
CREATE_PARENT_COMMUNITY_RANKS = """
MATCH (c:__Community__)<-[:PARENT_COMMUNITY*]-(:__Community__)<-[:IN_COMMUNITY*]-(:!Chunk&!Document&!__Community__)<-[:HAS_ENTITY]-(:Chunk)<-[]-(d:Document)
WITH c, count(distinct d) AS rank
SET c.community_rank = rank;
"""
# Weight = number of distinct chunks whose entities belong to the community.
CREATE_COMMUNITY_WEIGHTS = """
MATCH (n:`__Community__`)<-[:IN_COMMUNITY]-()<-[:HAS_ENTITY]-(c)
WITH n, count(distinct c) AS chunkCount
SET n.weight = chunkCount
"""
CREATE_PARENT_COMMUNITY_WEIGHTS = """
MATCH (n:`__Community__`)<-[:PARENT_COMMUNITY*]-(:`__Community__`)<-[:IN_COMMUNITY]-()<-[:HAS_ENTITY]-(c)
WITH n, count(distinct c) AS chunkCount
SET n.weight = chunkCount
"""
# Level-0 communities with more than one member, together with their member
# nodes and the relationships among those members (for LLM summarisation).
GET_COMMUNITY_INFO = """
MATCH (c:`__Community__`)<-[:IN_COMMUNITY]-(e)
WHERE c.level = 0
WITH c, collect(e) AS nodes
WHERE size(nodes) > 1
CALL apoc.path.subgraphAll(nodes[0], {
whitelistNodes:nodes
})
YIELD relationships
RETURN c.id AS communityId,
[n in nodes | {id: n.id, description: n.description, type: [el in labels(n) WHERE el <> '__Entity__'][0]}] AS nodes,
[r in relationships | {start: startNode(r).id, type: type(r), end: endNode(r).id}] AS rels
"""
# Parent communities still lacking a summary, with their children's summaries.
GET_PARENT_COMMUNITY_INFO = """
MATCH (p:`__Community__`)<-[:PARENT_COMMUNITY*]-(c:`__Community__`)
WHERE p.summary is null and c.summary is not null
RETURN p.id as communityId, collect(c.summary) as texts
"""
STORE_COMMUNITY_SUMMARIES = """
UNWIND $data AS row
MERGE (c:__Community__ {id:row.community})
SET c.summary = row.summary,
c.title = row.title
"""
# --- LLM prompt templates -------------------------------------------------
COMMUNITY_SYSTEM_TEMPLATE = "Given input triples, generate the information summary. No pre-amble."
COMMUNITY_TEMPLATE = """
Based on the provided nodes and relationships that belong to the same graph community,
generate following output in exact format
title: A concise title, no more than 4 words,
summary: A natural language summary of the information
{community_info}
Example output:
title: Example Title,
summary: This is an example summary that describes the key information of this community.
"""
PARENT_COMMUNITY_SYSTEM_TEMPLATE = "Given an input list of community summaries, generate a summary of the information"
PARENT_COMMUNITY_TEMPLATE = """Based on the provided list of community summaries that belong to the same graph community,
generate following output in exact format
title: A concise title, no more than 4 words,
summary: A natural language summary of the information. Include all the necessary information as much as possible.
{community_info}
Example output:
title: Example Title,
summary: This is an example summary that describes the key information of this community.
"""
# --- Embedding and index maintenance queries ------------------------------
GET_COMMUNITY_DETAILS = """
MATCH (c:`__Community__`)
WHERE c.embedding IS NULL AND c.summary IS NOT NULL
RETURN c.id as communityId, c.summary as text
"""
WRITE_COMMUNITY_EMBEDDINGS = """
UNWIND $rows AS row
MATCH (c) WHERE c.id = row.communityId
CALL db.create.setNodeVectorProperty(c, "embedding", row.embedding)
"""
DROP_COMMUNITIES = "MATCH (c:`__Community__`) DETACH DELETE c"
DROP_COMMUNITY_PROPERTY = "MATCH (e:`__Entity__`) REMOVE e.communities"
ENTITY_VECTOR_INDEX_NAME = "entity_vector"
ENTITY_VECTOR_EMBEDDING_DIMENSION = 384
DROP_ENTITY_VECTOR_INDEX_QUERY = f"DROP INDEX {ENTITY_VECTOR_INDEX_NAME} IF EXISTS;"
CREATE_ENTITY_VECTOR_INDEX_QUERY = """
CREATE VECTOR INDEX {index_name} IF NOT EXISTS FOR (e:__Entity__) ON e.embedding
OPTIONS {{
indexConfig: {{
`vector.dimensions`: {embedding_dimension},
`vector.similarity_function`: 'cosine'
}}
}}
"""
COMMUNITY_VECTOR_INDEX_NAME = "community_vector"
COMMUNITY_VECTOR_EMBEDDING_DIMENSION = 384
DROP_COMMUNITY_VECTOR_INDEX_QUERY = f"DROP INDEX {COMMUNITY_VECTOR_INDEX_NAME} IF EXISTS;"
CREATE_COMMUNITY_VECTOR_INDEX_QUERY = """
CREATE VECTOR INDEX {index_name} IF NOT EXISTS FOR (c:__Community__) ON c.embedding
OPTIONS {{
indexConfig: {{
`vector.dimensions`: {embedding_dimension},
`vector.similarity_function`: 'cosine'
}}
}}
"""
COMMUNITY_FULLTEXT_INDEX_NAME = "community_keyword"
COMMUNITY_FULLTEXT_INDEX_DROP_QUERY = f"DROP INDEX {COMMUNITY_FULLTEXT_INDEX_NAME} IF EXISTS;"
COMMUNITY_INDEX_FULL_TEXT_QUERY = f"CREATE FULLTEXT INDEX {COMMUNITY_FULLTEXT_INDEX_NAME} FOR (n:`__Community__`) ON EACH [n.summary]"
def get_gds_driver(uri, username, password, database):
    """Create and return a GraphDataScience client for the given Neo4j endpoint.

    Any of ``username`` / ``password`` / ``database`` that is None falls back
    to the corresponding NEO4J_* environment variable.

    Raises whatever the GDS client raises when the connection cannot be made.
    """
    try:
        # Fall back per value. Previously the env fallback only fired when
        # BOTH username and password were None, and the database fallback was
        # bundled into that same branch — so an explicitly supplied database
        # was clobbered whenever credentials came from the environment.
        if username is None:
            username = os.getenv('NEO4J_USERNAME')
        if password is None:
            password = os.getenv('NEO4J_PASSWORD')
        if database is None:
            database = os.getenv('NEO4J_DATABASE')
        gds = GraphDataScience(
            endpoint=uri,
            auth=(username, password),
            database=database
        )
        logging.info("Successfully created GDS driver.")
        return gds
    except Exception as e:
        logging.error(f"Failed to create GDS driver: {e}")
        raise
def create_community_graph_projection(gds, project_name=COMMUNITY_PROJECTION_NAME, node_projection=NODE_PROJECTION):
    """(Re)create the in-memory GDS projection used for community detection.

    Drops any existing projection with the same name, then projects an
    undirected weighted graph over the nodes selected by ``node_projection``.
    Returns the GDS Graph object for the new projection; re-raises on failure.
    """
    try:
        existing_projects = gds.graph.list()
        # Exact-name match. The previous `str.contains(..., regex=False)` was
        # a substring search, so an unrelated projection such as
        # 'communities_backup' would wrongly be detected (and dropped).
        project_exists = (existing_projects["graphName"] == project_name).any()
        if project_exists:
            logging.info(f"Projection '{project_name}' already exists. Dropping it.")
            gds.graph.drop(project_name)
        logging.info(f"Creating new graph project '{project_name}'.")
        projection_query = CREATE_COMMUNITY_GRAPH_PROJECTION.format(node_projection=node_projection, project_name=project_name)
        graph_projection_result = gds.run_cypher(projection_query)
        projection_result = graph_projection_result.to_dict(orient="records")[0]
        logging.info(f"Graph projection '{projection_result['graph_name']}' created successfully with {projection_result['nodes']} nodes and {projection_result['rels']} relationships.")
        graph_project = gds.graph.get(projection_result['graph_name'])
        return graph_project
    except Exception as e:
        logging.error(f"Failed to create community graph project: {e}")
        raise
def write_communities(gds, graph_project, project_name=COMMUNITY_PROJECTION_NAME):
    """Run Leiden community detection on the projection and persist the
    (intermediate) community ids onto the nodes as `project_name`.

    Returns True on success, False if the write failed.
    """
    leiden_options = dict(
        writeProperty=project_name,
        includeIntermediateCommunities=True,
        relationshipWeightProperty="weight",
        maxLevels=MAX_COMMUNITY_LEVELS,
        minCommunitySize=MIN_COMMUNITY_SIZE,
    )
    try:
        logging.info(f"Writing communities to the graph project '{project_name}'.")
        gds.leiden.write(graph_project, **leiden_options)
    except Exception as e:
        logging.error(f"Failed to write communities: {e}")
        return False
    logging.info("Communities written successfully.")
    return True
def get_community_chain(model, is_parent=False,community_template=COMMUNITY_TEMPLATE,system_template=COMMUNITY_SYSTEM_TEMPLATE):
    """Build an LCEL chain (prompt | llm | parser) that summarises a community.

    When ``is_parent`` is True, the parent-community templates take precedence
    over whatever templates were passed in. Re-raises on failure.
    """
    try:
        if is_parent:
            community_template = PARENT_COMMUNITY_TEMPLATE
            system_template = PARENT_COMMUNITY_SYSTEM_TEMPLATE
        llm, _model_name = get_llm(model)
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", community_template),
            ]
        )
        return prompt | llm | StrOutputParser()
    except Exception as e:
        logging.error(f"Failed to create community chain: {e}")
        raise
def prepare_string(community_data):
    """Render a community's nodes and relationships as a plain-text listing
    suitable for an LLM prompt.

    ``community_data`` is a dict with 'nodes' (each: id, type, optional
    description) and 'rels' (each: start, type, end, optional description).
    Empty/None descriptions are omitted. Returns the combined string;
    re-raises on malformed input (e.g. missing required keys).
    """
    try:
        # Build line lists and join once: idiomatic dict.get() instead of the
        # `'k' in d and d['k']` dance, and no quadratic `+=` concatenation.
        node_lines = ["Nodes are:"]
        for node in community_data['nodes']:
            description = node.get('description')
            suffix = f", description: {description}" if description else ""
            node_lines.append(f"id: {node['id']}, type: {node['type']}{suffix}")
        rel_lines = ["Relationships are:"]
        for rel in community_data['rels']:
            description = rel.get('description')
            suffix = f", description: {description}" if description else ""
            rel_lines.append(f"({rel['start']})-[:{rel['type']}]->({rel['end']}){suffix}")
        # Each section ends with a newline; a blank line separates them.
        return "\n".join(node_lines) + "\n\n" + "\n".join(rel_lines) + "\n"
    except Exception as e:
        logging.error(f"Failed to prepare string from community data: {e}")
        raise
def process_community_info(community, chain, is_parent=False):
    """Summarise one community through the LLM chain and parse the reply.

    For parent communities the child summaries are joined into one text;
    otherwise the node/relationship listing is built via prepare_string().
    Returns {'community', 'title', 'summary'} or None if anything failed.
    """
    try:
        if is_parent:
            texts = community.get("texts", [])
            community_info = " ".join(f"Summary {idx + 1}: {text}" for idx, text in enumerate(texts))
        else:
            community_info = prepare_string(community)
        response = chain.invoke({'community_info': community_info})
        # Parse "title: ..." / "summary: ..." lines, case-insensitively.
        title = "Untitled Community"
        summary = ""
        for raw_line in response.splitlines():
            lowered = raw_line.lower()
            if lowered.startswith("title"):
                title = raw_line.split(":", 1)[-1].strip()
            elif lowered.startswith("summary"):
                summary = raw_line.split(":", 1)[-1].strip()
        logging.info(f"Community Title : {title}")
        return {"community": community['communityId'], "title": title, "summary": summary}
    except Exception as e:
        logging.error(f"Failed to process community {community.get('communityId', 'unknown')}: {e}")
        return None
def create_community_summaries(gds, model):
    """Generate LLM summaries for leaf communities, store them, then do the
    same for parent communities (which summarise their children's summaries).
    Re-raises on failure."""
    def _summarize_all(info_frame, chain, is_parent, failure_message):
        # Fan the per-community LLM calls out over a thread pool and collect
        # the successful results; failures were already logged per community.
        collected = []
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(process_community_info, record, chain, is_parent=is_parent)
                for record in info_frame.to_dict(orient="records")
            ]
            for future in as_completed(futures):
                outcome = future.result()
                if outcome:
                    collected.append(outcome)
                else:
                    logging.error(failure_message)
        return collected

    try:
        leaf_info = gds.run_cypher(GET_COMMUNITY_INFO)
        leaf_chain = get_community_chain(model)
        summaries = _summarize_all(leaf_info, leaf_chain, False, "community summaries could not be processed.")
        gds.run_cypher(STORE_COMMUNITY_SUMMARIES, params={"data": summaries})

        parent_info = gds.run_cypher(GET_PARENT_COMMUNITY_INFO)
        parent_chain = get_community_chain(model, is_parent=True)
        parent_summaries = _summarize_all(parent_info, parent_chain, True, "parent community summaries could not be processed.")
        gds.run_cypher(STORE_COMMUNITY_SUMMARIES, params={"data": parent_summaries})
    except Exception as e:
        logging.error(f"Failed to create community summaries: {e}")
        raise
def create_community_embeddings(gds):
    """Embed each community summary and write the vectors back to the graph.

    Returns the embedding dimension on success, or None if the overall
    process failed (best-effort: the caller falls back to a default
    dimension when creating the vector indexes).
    """
    try:
        embedding_model = get_value_from_env_or_secret_manager("EMBEDDING_MODEL")
        embeddings, dimension = load_embedding_model(embedding_model)
        logging.info(f"Embedding model '{embedding_model}' loaded successfully.")
        logging.info("Fetching community details.")
        rows = gds.run_cypher(GET_COMMUNITY_DETAILS)
        rows = rows[['communityId', 'text']].to_dict(orient='records')
        logging.info(f"Fetched {len(rows)} communities.")
        batch_size = 100
        for i in range(0, len(rows), batch_size):
            batch_rows = rows[i:i+batch_size]
            for row in batch_rows:
                try:
                    row['embedding'] = embeddings.embed_query(row['text'])
                except Exception as e:
                    logging.error(f"Failed to embed text for community ID {row['communityId']}: {e}")
                    row['embedding'] = None
            # Only write rows that embedded successfully: previously rows with
            # embedding=None were sent too, and a null vector makes
            # db.create.setNodeVectorProperty fail the whole batch.
            valid_rows = [row for row in batch_rows if row['embedding'] is not None]
            if not valid_rows:
                continue
            try:
                logging.info("Writing embeddings to the database.")
                gds.run_cypher(WRITE_COMMUNITY_EMBEDDINGS, params={'rows': valid_rows})
                logging.info("Embeddings written successfully.")
            except Exception as e:
                logging.error(f"Failed to write embeddings to the database: {e}")
                continue
        return dimension
    except Exception as e:
        # Deliberately swallowed (returns None) so community creation can
        # continue with default index dimensions.
        logging.error(f"An error occurred during the community embedding process: {e}")
def create_vector_index(gds, index_type,embedding_dimension=None):
    """Drop and recreate the vector index for entities or communities.

    ``index_type`` must be ENTITY_VECTOR_INDEX_NAME or
    COMMUNITY_VECTOR_INDEX_NAME; anything else is logged and ignored. When
    ``embedding_dimension`` is falsy, the type's default dimension is used.
    Errors are logged, not raised.
    """
    # Dispatch table: (drop query, create template, default dimension).
    index_settings = {
        ENTITY_VECTOR_INDEX_NAME: (
            DROP_ENTITY_VECTOR_INDEX_QUERY,
            CREATE_ENTITY_VECTOR_INDEX_QUERY,
            ENTITY_VECTOR_EMBEDDING_DIMENSION,
        ),
        COMMUNITY_VECTOR_INDEX_NAME: (
            DROP_COMMUNITY_VECTOR_INDEX_QUERY,
            CREATE_COMMUNITY_VECTOR_INDEX_QUERY,
            COMMUNITY_VECTOR_EMBEDDING_DIMENSION,
        ),
    }
    if index_type not in index_settings:
        logging.error(f"Invalid index type provided: {index_type}")
        return
    drop_query, create_template, default_dimension = index_settings[index_type]
    query = create_template.format(
        index_name=index_type,
        embedding_dimension=embedding_dimension if embedding_dimension else default_dimension,
    )
    try:
        logging.info("Starting the process to create vector index.")
        logging.info(f"Executing drop query: {drop_query}")
        gds.run_cypher(drop_query)
        logging.info(f"Executing create query: {query}")
        gds.run_cypher(query)
        logging.info(f"Vector index '{index_type}' created successfully.")
    except Exception as e:
        logging.error("An error occurred while creating the vector index.", exc_info=True)
        logging.error(f"Error details: {str(e)}")
def create_fulltext_index(gds, index_type):
    """Drop and recreate the community full-text (keyword) index.

    Only COMMUNITY_FULLTEXT_INDEX_NAME is supported; other values are logged
    and ignored. Errors are logged, not raised.
    """
    if index_type != COMMUNITY_FULLTEXT_INDEX_NAME:
        logging.error(f"Invalid index type provided: {index_type}")
        return
    drop_query = COMMUNITY_FULLTEXT_INDEX_DROP_QUERY
    query = COMMUNITY_INDEX_FULL_TEXT_QUERY
    try:
        logging.info("Starting the process to create full-text index.")
        logging.info(f"Executing drop query: {drop_query}")
        gds.run_cypher(drop_query)
        logging.info(f"Executing create query: {query}")
        gds.run_cypher(query)
        logging.info(f"Full-text index '{index_type}' created successfully.")
    except Exception as e:
        logging.error("An error occurred while creating the full-text index.", exc_info=True)
        logging.error(f"Error details: {str(e)}")
def create_community_properties(gds, model):
    """Run the full post-clustering pipeline: constraints, community levels,
    ranks and weights, then LLM summaries, embeddings, and the vector and
    full-text indexes. Re-raises on failure."""
    setup_steps = (
        (CREATE_COMMUNITY_CONSTRAINT, "created community constraint to the graph."),
        (CREATE_COMMUNITY_LEVELS, "Successfully created community levels."),
        (CREATE_COMMUNITY_RANKS, "Successfully created community ranks."),
        (CREATE_PARENT_COMMUNITY_RANKS, "Successfully created parent community ranks."),
        (CREATE_COMMUNITY_WEIGHTS, "Successfully created community weights."),
        (CREATE_PARENT_COMMUNITY_WEIGHTS, "Successfully created parent community weights."),
    )
    try:
        # Structural properties first (levels/ranks/weights are plain Cypher).
        for cypher, success_message in setup_steps:
            gds.run_cypher(cypher)
            logging.info(success_message)
        create_community_summaries(gds, model)
        logging.info("Successfully created community summaries.")
        # Embedding dimension feeds both vector indexes (None -> defaults).
        embedding_dimension = create_community_embeddings(gds)
        logging.info("Successfully created community embeddings.")
        create_vector_index(gds=gds, index_type=ENTITY_VECTOR_INDEX_NAME, embedding_dimension=embedding_dimension)
        logging.info("Successfully created Entity Vector Index.")
        create_vector_index(gds=gds, index_type=COMMUNITY_VECTOR_INDEX_NAME, embedding_dimension=embedding_dimension)
        logging.info("Successfully created community Vector Index.")
        create_fulltext_index(gds=gds, index_type=COMMUNITY_FULLTEXT_INDEX_NAME)
        logging.info("Successfully created community fulltext Index.")
    except Exception as e:
        logging.error(f"Error during community properties creation: {e}")
        raise
def clear_communities(gds):
    """Delete all __Community__ nodes and remove the `communities` property
    from entities, resetting the graph for a fresh detection run.
    Re-raises on failure."""
    cleanup_steps = (
        ("Dropping communities...", DROP_COMMUNITIES, "Communities dropped successfully"),
        ("Dropping community property from entities...", DROP_COMMUNITY_PROPERTY, "Community property dropped successfully"),
    )
    try:
        logging.info("Starting to clear communities.")
        for start_message, cypher, done_message in cleanup_steps:
            logging.info(start_message)
            gds.run_cypher(cypher)
            logging.info(done_message)
    except Exception as e:
        logging.error(f"An error occurred while clearing communities: {e}")
        raise
def create_communities(uri, username, password, database,model=COMMUNITY_CREATION_DEFAULT_MODEL):
    """End-to-end community build: connect, clear previous communities,
    project the graph, run Leiden, then enrich communities with properties,
    summaries, embeddings and indexes. Errors are logged, not raised."""
    try:
        gds = get_gds_driver(uri, username, password, database)
        clear_communities(gds)
        projection = create_community_graph_projection(gds)
        if write_communities(gds, projection):
            logging.info("Starting Community properties creation process.")
            create_community_properties(gds, model)
            logging.info("Communities creation process completed successfully.")
        else:
            logging.warning("Failed to write communities. Constraint was not applied.")
    except Exception as e:
        logging.error(f"Failed to create communities: {e}")