Neo4j message history #273

Merged · 18 commits · Feb 21, 2025
Changes from 17 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@
- Support for effective_search_ratio parameter in vector and hybrid searches.
- Introduced upsert_vectors utility function for batch upserting embeddings to vector indexes.
- Introduced `extract_cypher` function to enhance Cypher query extraction and formatting in `Text2CypherRetriever`.
- Introduced `Neo4jMessageHistory` and `InMemoryMessageHistory` classes for managing LLM message histories.
- Added examples and documentation for using message history with Neo4j and in-memory storage.
- Updated LLM and GraphRAG classes to support new message history classes.

### Changed

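As a quick orientation for the changelog entries above, here is a minimal sketch of the new message-history API (assuming `InMemoryMessageHistory` takes no constructor arguments; the other calls mirror the examples added in this PR):

```python
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
from neo4j_graphrag.message_history import InMemoryMessageHistory

# Assumption: no-arg constructor for the in-memory store.
history = InMemoryMessageHistory()
history.add_message(
    {"role": "user", "content": "What are some movies Tom Hanks starred in?"}
)

# Requires OPENAI_API_KEY in the environment.
llm = OpenAILLM(model_name="gpt-4o")
res: LLMResponse = llm.invoke("Is he also a director?", message_history=history)
print(res.content)
```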
9 changes: 9 additions & 0 deletions docs/source/api.rst
@@ -403,6 +403,15 @@ Database Interaction
.. autofunction:: neo4j_graphrag.schema.format_schema


***************
Message History
***************

.. autoclass:: neo4j_graphrag.message_history.InMemoryMessageHistory

.. autoclass:: neo4j_graphrag.message_history.Neo4jMessageHistory


******
Errors
******
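The two classes documented above appear to share the `MessageHistory` interface (it is the type `GraphRAG.search` checks against later in this diff), so they can be swapped freely. A hedged sketch, with placeholder connection details:

```python
import neo4j

from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory

driver = neo4j.GraphDatabase.driver(  # placeholder local instance
    "neo4j://localhost:7687", auth=("neo4j", "password")
)

histories = [
    InMemoryMessageHistory(),  # assumption: no-arg constructor
    Neo4jMessageHistory(session_id="my-session", driver=driver, window=10),
]
for history in histories:
    history.add_message({"role": "user", "content": "Hello"})
    print(history.messages)  # both return a list of LLMMessage dicts
```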
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -148,6 +148,7 @@ Note that the below example is not the only way you can upsert data into your Neo4j


.. code:: python

from neo4j import GraphDatabase
from neo4j_graphrag.indexes import upsert_vectors
from neo4j_graphrag.types import EntityType
Expand Down
1 change: 1 addition & 0 deletions docs/source/user_guide_rag.rst
@@ -917,6 +917,7 @@ Populate a Vector Index
==========================

.. code:: python

from random import random

from neo4j import GraphDatabase
Expand Down
3 changes: 2 additions & 1 deletion examples/README.md
@@ -52,7 +52,7 @@ are listed in [the last section of this file](#customize).

- [End to end GraphRAG](./answer/graphrag.py)
- [GraphRAG with message history](./question_answering/graphrag_with_message_history.py)
- [GraphRAG with Neo4j message history](./question_answering/graphrag_with_neo4j_message_history.py)

## Customize

@@ -75,6 +75,7 @@ are listed in [the last section of this file](#customize).
- [Custom LLM](./customize/llms/custom_llm.py)

- [Message history](./customize/llms/llm_with_message_history.py)
- [Message history with Neo4j](./customize/llms/llm_with_neo4j_message_history.py)
- [System Instruction](./customize/llms/llm_with_system_instructions.py)


Expand Down
7 changes: 4 additions & 3 deletions examples/customize/llms/custom_llm.py
@@ -1,9 +1,10 @@
import random
import string
from typing import Any, Optional
from typing import Any, List, Optional, Union

from neo4j_graphrag.llm import LLMInterface, LLMResponse
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import MessageHistory


class CustomLLM(LLMInterface):
@@ -15,7 +16,7 @@ def __init__(
def invoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
content: str = (
@@ -26,7 +27,7 @@ def invoke(
async def ainvoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
raise NotImplementedError()
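With the widened signature, a custom `LLMInterface` implementation has to handle both a plain list and a `MessageHistory` object. A minimal normalization sketch (the `normalize_history` helper is hypothetical, not part of the package; the `.messages` unwrap mirrors what `GraphRAG.search` does later in this diff):

```python
from typing import List, Optional, Union

from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import MessageHistory


def normalize_history(
    message_history: Optional[Union[List[LLMMessage], MessageHistory]],
) -> List[LLMMessage]:
    """Unwrap either accepted form into a plain list of messages."""
    if message_history is None:
        return []
    if isinstance(message_history, MessageHistory):
        return message_history.messages
    return message_history
```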
59 changes: 59 additions & 0 deletions examples/customize/llms/llm_with_neo4j_message_history.py
@@ -0,0 +1,59 @@
"""This example illustrates the message_history feature
of the LLMInterface by mocking a conversation between a user
and an LLM about Tom Hanks.

Neo4j is used as the database for storing the message history.

OpenAILLM can be replaced by any supported LLM from this package.
"""

import neo4j
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
from neo4j_graphrag.message_history import Neo4jMessageHistory

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"

# set api key here or in the OPENAI_API_KEY env var
api_key = None

llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)

questions = [
"What are some movies Tom Hanks starred in?",
"Is he also a director?",
"Wow, that's impressive. And what about his personal life, does he have children?",
]

driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

for question in questions:
res: LLMResponse = llm.invoke(
question,
message_history=history,
)
history.add_message(
{
"role": "user",
"content": question,
}
)
history.add_message(
{
"role": "assistant",
"content": res.content,
}
)

print("#" * 50, question)
print(res.content)
print("#" * 50)
87 changes: 87 additions & 0 deletions examples/question_answering/graphrag_with_neo4j_message_history.py
@@ -0,0 +1,87 @@
"""End to end example of building a RAG pipeline backed by a Neo4j database,
simulating a chat with message history which is also stored in Neo4j.

Requires OPENAI_API_KEY to be in the env var.
"""

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.message_history import Neo4jMessageHistory
from neo4j_graphrag.retrievers import VectorCypherRetriever

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"


driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
)

embedder = OpenAIEmbeddings()

retriever = VectorCypherRetriever(
driver,
index_name=INDEX,
retrieval_query="""
WITH node as movie, score
CALL(movie) {
MATCH (movie)<-[:ACTED_IN]-(p:Person)
RETURN collect(p.name) as actors
}
CALL(movie) {
MATCH (movie)<-[:DIRECTED]-(p:Person)
RETURN collect(p.name) as directors
}
RETURN movie.title as title, movie.plot as plot, movie.year as year, actors, directors
""",
embedder=embedder,
neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

rag = GraphRAG(
retriever=retriever,
llm=llm,
)

history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

questions = [
"Who starred in the Apollo 13 movies?",
"Who was its director?",
"In which year was this movie released?",
]

for question in questions:
result = rag.search(
question,
return_context=False,
message_history=history,
)

answer = result.answer
print("#" * 50, question)
print(answer)
print("#" * 50)

history.add_message(
{
"role": "user",
"content": question,
}
)
history.add_message(
{
"role": "assistant",
"content": answer,
}
)

driver.close()
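A side note on the retrieval query above: the `CALL(movie) { ... }` blocks use Cypher's newer variable-scope subquery syntax; on server versions that predate it, the equivalent spelling is `CALL { WITH movie ... }`.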
25 changes: 15 additions & 10 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -16,7 +16,7 @@

import logging
import warnings
from typing import Any, Optional
from typing import Any, List, Optional, Union

from pydantic import ValidationError

@@ -28,6 +28,7 @@
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RetrieverResult

@@ -84,7 +85,7 @@ def __init__(
def search(
self,
query_text: str = "",
message_history: Optional[list[LLMMessage]] = None,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool | None = None,
@@ -102,7 +103,8 @@ def search(

Args:
query_text (str): The user question.
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
with each message having a specific role assigned.
examples (str): Examples added to the LLM prompt.
retriever_config (Optional[dict]): Parameters passed to the retriever's
search method; e.g.: top_k
@@ -127,7 +129,9 @@
)
except ValidationError as e:
raise SearchValidationError(e.errors())
query = self.build_query(validated_data.query_text, message_history)
if isinstance(message_history, MessageHistory):
message_history = message_history.messages
query = self._build_query(validated_data.query_text, message_history)
retriever_result: RetrieverResult = self.retriever.search(
query_text=query, **validated_data.retriever_config
)
@@ -147,12 +151,14 @@
result["retriever_result"] = retriever_result
return RagResultModel(**result)

def build_query(
self, query_text: str, message_history: Optional[list[LLMMessage]] = None
def _build_query(
self,
query_text: str,
message_history: Optional[List[LLMMessage]] = None,
) -> str:
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
if message_history:
summarization_prompt = self.chat_summary_prompt(
summarization_prompt = self._chat_summary_prompt(
message_history=message_history
)
summary = self.llm.invoke(
@@ -162,10 +168,9 @@ def build_query(
return self.conversation_prompt(summary=summary, current_query=query_text)
return query_text

def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str:
message_list = [
": ".join([f"{value}" for _, value in message.items()])
for message in message_history
f"{message['role']}: {message['content']}" for message in message_history
]
history = "\n".join(message_list)
return f"""
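The reworked `_chat_summary_prompt` formatting is easiest to check against a concrete input. An illustration with invented message contents, showing why the explicit `role: content` f-string beats joining all dict values:

```python
# Hypothetical two-turn history, in the LLMMessage dict shape used above.
messages = [
    {"role": "user", "content": "Who starred in Apollo 13?"},
    {"role": "assistant", "content": "Tom Hanks, Kevin Bacon and Bill Paxton."},
]

# New formatting: the role label is kept explicit, and the result no longer
# depends on dict key order or on extra keys sneaking into a message.
history = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
print(history)
# user: Who starred in Apollo 13?
# assistant: Tom Hanks, Kevin Bacon and Bill Paxton.
```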