Skip to content

Commit be52f6b

Browse files
Entity embed (#505)
* removed version of pydantic * updating embedding of entities in post processing * schema extraction get llm correction * added the create_entity_embedding task in post processing --------- Co-authored-by: kartikpersistent <[email protected]>
1 parent db9b175 commit be52f6b

File tree

6 files changed

+44
-9
lines changed

6 files changed

+44
-9
lines changed

Diff for: backend/example.env

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ GCS_FILE_CACHE = "" #save the file into GCS or local, Should be True or False
2424
NEO4J_USER_AGENT=""
2525
ENABLE_USER_AGENT = ""
2626
LLM_MODEL_CONFIG_model_version=""
27+
ENTITY_EMBEDDING="" #Should be True or False
2728
#examples
2829
LLM_MODEL_CONFIG_azure-ai-gpt-35="azure_deployment_name,azure_endpoint or base_url,azure_api_key,api_version"
2930
LLM_MODEL_CONFIG_azure-ai-gpt-4o="gpt-4o,https://YOUR-ENDPOINT.openai.azure.com/,azure_api_key,api_version"

Diff for: backend/score.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from src.graphDB_dataAccess import graphDBdataAccess
1919
from src.graph_query import get_graph_results
2020
from src.chunkid_entities import get_entities_from_chunkids
21-
from src.post_processing import create_fulltext
21+
from src.post_processing import create_fulltext, create_entity_embedding
2222
from sse_starlette.sse import EventSourceResponse
2323
import json
2424
from typing import List, Mapping
@@ -262,7 +262,11 @@ async def post_processing(uri=Form(None), userName=Form(None), password=Form(Non
262262
josn_obj = {'api_name': 'post_processing/create_fulltext_index', 'db_url': uri, 'logging_time': formatted_time(datetime.now(timezone.utc))}
263263
logger.log_struct(josn_obj)
264264
logging.info(f'Full Text index created')
265-
265+
if os.environ.get('ENTITY_EMBEDDING').upper()=="TRUE" and "create_entity_embedding" in tasks:
266+
await asyncio.to_thread(create_entity_embedding, graph)
267+
josn_obj = {'api_name': 'post_processing/create_entity_embedding', 'db_url': uri, 'logging_time': formatted_time(datetime.now(timezone.utc))}
268+
logger.log_struct(josn_obj)
269+
logging.info(f'Entity Embeddings created')
266270
return create_api_response('Success', message='All tasks completed successfully')
267271

268272
except Exception as e:

Diff for: backend/src/main.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from langchain_community.graphs import Neo4jGraph
22
from src.shared.constants import BUCKET_UPLOAD, PROJECT_ID
3-
from src.shared.schema_extraction import sceham_extraction_from_text
3+
from src.shared.schema_extraction import schema_extraction_from_text
44
from dotenv import load_dotenv
55
from datetime import datetime
66
import logging
@@ -548,5 +548,5 @@ def populate_graph_schema_from_text(text, model, is_schema_description_cheked):
548548
Returns:
549549
data (list): list of labels and relationTypes
550550
"""
551-
result = sceham_extraction_from_text(text, model, is_schema_description_cheked)
551+
result = schema_extraction_from_text(text, model, is_schema_description_cheked)
552552
return {"labels": result.labels, "relationshipTypes": result.relationshipTypes}

Diff for: backend/src/post_processing.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from neo4j import GraphDatabase
22
import logging
33
import time
4-
4+
from langchain_community.graphs import Neo4jGraph
5+
import os
6+
from src.shared.common_fn import load_embedding_model
57

68
DROP_INDEX_QUERY = "DROP INDEX entities IF EXISTS;"
79
LABELS_QUERY = "CALL db.labels()"
@@ -55,4 +57,32 @@ def create_fulltext(uri, username, password, database):
5557
finally:
5658
driver.close()
5759
logging.info("Driver closed.")
58-
logging.info(f"Process completed in {time.time() - start_time:.2f} seconds.")
60+
logging.info(f"Process completed in {time.time() - start_time:.2f} seconds.")
61+
62+
63+
def create_entity_embedding(graph: "Neo4jGraph", batch_size: int = 1000):
    """Compute and store embeddings for entity nodes that lack one.

    Candidate rows are fetched with ``fetch_entities_for_embedding`` and
    passed to ``update_embeddings`` in consecutive slices.

    Args:
        graph: Graph connection forwarded to the fetch and update helpers.
        batch_size: Maximum number of rows handed to ``update_embeddings``
            per call. Defaults to 1000, the previously hard-coded value.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # Validate before touching the database: a zero step makes range() raise
    # and a negative step would silently skip every row.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    rows = fetch_entities_for_embedding(graph)
    for start in range(0, len(rows), batch_size):
        update_embeddings(rows[start:start + batch_size], graph)
67+
68+
def fetch_entities_for_embedding(graph):
    """Return the nodes that still need an embedding.

    The query selects every node that is neither a ``Chunk`` nor a
    ``Document``, has an ``id`` and has no ``embedding`` property yet.

    Args:
        graph: Object exposing ``query(cypher)``; each returned record must
            support lookup by the ``elementId`` and ``text`` keys.

    Returns:
        A list of ``{"elementId": ..., "text": ...}`` dicts, where ``text`` is
        the node id followed by a space and its description (empty string
        when the description is missing).
    """
    cypher = """
                MATCH (e)
                WHERE NOT (e:Chunk OR e:Document) AND e.embedding IS NULL AND e.id IS NOT NULL
                RETURN elementId(e) AS elementId, e.id + " " + coalesce(e.description, "") AS text
                """
    entities = []
    for record in graph.query(cypher):
        # Copy only the two fields needed downstream into a plain dict.
        entities.append({"elementId": record["elementId"], "text": record["text"]})
    return entities
76+
77+
def update_embeddings(rows, graph):
    """Embed each row's text and write the vector onto the matching node.

    The embedding model name is read from the ``EMBEDDING_MODEL`` environment
    variable and resolved through ``load_embedding_model``.

    Args:
        rows: List of dicts carrying ``elementId`` and ``text`` keys. Each
            dict is mutated in place: an ``embedding`` key is added.
        graph: Object exposing ``query(query, params=...)`` used to run the
            Cypher update.

    Returns:
        The result of ``graph.query`` for the update statement, or an empty
        list when ``rows`` is empty.
    """
    # Nothing to embed: skip loading the model and the database round trip.
    if not rows:
        return []

    embedding_model = os.getenv('EMBEDDING_MODEL')
    # The second element returned by the loader is not needed here.
    embeddings, _ = load_embedding_model(embedding_model)
    logging.info("update embedding for entities")
    for row in rows:
        row['embedding'] = embeddings.embed_query(row['text'])
    query = """
      UNWIND $rows AS row
      MATCH (e) WHERE elementId(e) = row.elementId
      CALL db.create.setNodeVectorProperty(e, "embedding", row.embedding)
      """
    return graph.query(query, params={'rows': rows})

Diff for: backend/src/shared/schema_extraction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class Schema(BaseModel):
2525
"Only return the string types for nodes and relationships, don't return attributes."
2626
)
2727

28-
def sceham_extraction_from_text(input_text:str, model:str, is_schema_description_cheked:bool):
28+
def schema_extraction_from_text(input_text:str, model:str, is_schema_description_cheked:bool):
2929

30-
llm = get_llm(MODEL_VERSIONS[model])
30+
llm, model_name = get_llm(model)
3131
if is_schema_description_cheked:
3232
schema_prompt = PROMPT_TEMPLATE_WITH_SCHEMA
3333
else:

Diff for: frontend/src/utils/Constants.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ export const ChatModeOptions = [
161161
{ Icon: 'abc', value: 'vector' },
162162
];
163163

164-
export const taskParam: string[] = ['update_similarity_graph', 'create_fulltext_index'];
164+
export const taskParam: string[] = ['update_similarity_graph', 'create_fulltext_index','create_entity_embedding'];
165165

166166
export const nvlOptions: NvlOptions = {
167167
allowDynamicMinZoom: true,

0 commit comments

Comments
 (0)