From eac9f094734fc4c47468ab33c4d1a1aa2956c597 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Mon, 17 Feb 2025 17:21:33 +0000 Subject: [PATCH 01/18] Added message history classes --- src/neo4j_graphrag/message_history.py | 145 ++++++++++++++++++++++++++ src/neo4j_graphrag/types.py | 9 +- tests/e2e/test_message_history_e2e.py | 88 ++++++++++++++++ tests/unit/test_message_history.py | 93 +++++++++++++++++ 4 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 src/neo4j_graphrag/message_history.py create mode 100644 tests/e2e/test_message_history_e2e.py create mode 100644 tests/unit/test_message_history.py diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py new file mode 100644 index 00000000..4c7d3650 --- /dev/null +++ b/src/neo4j_graphrag/message_history.py @@ -0,0 +1,145 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Union + +import neo4j +from pydantic import PositiveInt + +from neo4j_graphrag.llm.types import ( + LLMMessage, +) +from neo4j_graphrag.types import ( + Neo4jDriverModel, + Neo4jMessageHistoryModel, +) + +CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})" + +CLEAR_SESSION_QUERY = ( + "MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) " + "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() " + "WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 " + "UNWIND nodes(p) as node DETACH DELETE node;" +) + +GET_MESSAGES_QUERY = ( + "MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) " + "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.." 
+ "{window}]-() WITH p, length(p) AS length " + "ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node " + "RETURN {{data:{{content: node.content}}, role:node.role}} AS result" +) + +ADD_MESSAGE_QUERY = ( + "MATCH (s:`{node_label}`) WHERE s.id = $session_id " + "OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) " + "CREATE (s)-[:LAST_MESSAGE]->(new:Message) " + "SET new += {{role:$role, content:$content}} " + "WITH new, lm, last_message WHERE last_message IS NOT NULL " + "CREATE (last_message)-[:NEXT]->(new) " + "DELETE lm" +) + + +class MessageHistory(ABC): + @property + @abstractmethod + def messages(self) -> List[LLMMessage]: ... + + @abstractmethod + def add_message(self, message: LLMMessage) -> None: ... + + def add_messages(self, messages: List[LLMMessage]) -> None: + for message in messages: + self.add_message(message) + + @abstractmethod + def clear(self) -> None: ... + + +class InMemoryMessageHistory(MessageHistory): + def __init__(self, messages: List[LLMMessage] = []) -> None: + self._messages = messages + + @property + def messages(self) -> List[LLMMessage]: + return self._messages + + def add_message(self, message: LLMMessage) -> None: + self._messages.append(message) + + def add_messages(self, messages: List[LLMMessage]) -> None: + self._messages.extend(messages) + + def clear(self) -> None: + self._messages = [] + + +class Neo4jMessageHistory(MessageHistory): + def __init__( + self, + session_id: Union[str, int], + driver: neo4j.Driver, + node_label: str = "Session", + window: Optional[PositiveInt] = None, + ) -> None: + validated_data = Neo4jMessageHistoryModel( + session_id=session_id, + driver_model=Neo4jDriverModel(driver=driver), + node_label=node_label, + window=window, + ) + self._driver = validated_data.driver_model.driver + self._session_id = validated_data.session_id + self._node_label = validated_data.node_label + self._window = ( + "" if validated_data.window is None else validated_data.window - 1 + ) + # Create session node + 
self._driver.execute_query( + query_=CREATE_SESSION_NODE_QUERY.format(node_label=self._node_label), + parameters_={"session_id": self._session_id}, + ) + + @property + def messages(self) -> List[LLMMessage]: + result = self._driver.execute_query( + query_=GET_MESSAGES_QUERY.format( + node_label=self._node_label, window=self._window + ), + parameters_={"session_id": self._session_id}, + ) + messages = [ + LLMMessage( + content=el["result"]["data"]["content"], + role=el["result"]["role"], + ) + for el in result.records + ] + return messages + + @messages.setter + def messages(self, messages: List[LLMMessage]) -> None: + raise NotImplementedError( + "Direct assignment to 'messages' is not allowed." + " Use the 'add_messages' instead." + ) + + def add_message(self, message: LLMMessage) -> None: + self._driver.execute_query( + query_=ADD_MESSAGE_QUERY.format(node_label=self._node_label), + parameters_={ + "role": message["role"], + "content": message["content"], + "session_id": self._session_id, + }, + ) + + def clear(self) -> None: + self._driver.execute_query( + query_=CLEAR_SESSION_QUERY.format(node_label=self._node_label), + parameters_={"session_id": self._session_id}, + ) + + def __del__(self) -> None: + if self._driver: + self._driver.close() diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index 3f147296..0e705ee5 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -15,7 +15,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Callable, Literal, Optional +from typing import Any, Callable, Literal, Optional, Union import neo4j from pydantic import ( @@ -251,3 +251,10 @@ class Text2CypherRetrieverModel(BaseModel): result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None custom_prompt: Optional[str] = None neo4j_database: Optional[str] = None + + +class Neo4jMessageHistoryModel(BaseModel): + session_id: Union[str, int] + driver_model: Neo4jDriverModel + 
node_label: str = "Session" + window: Optional[PositiveInt] = None diff --git a/tests/e2e/test_message_history_e2e.py b/tests/e2e/test_message_history_e2e.py new file mode 100644 index 00000000..f6413bda --- /dev/null +++ b/tests/e2e/test_message_history_e2e.py @@ -0,0 +1,88 @@ +import neo4j +from neo4j_graphrag.llm.types import LLMMessage +from neo4j_graphrag.message_history import Neo4jMessageHistory + + +def test_neo4j_message_history_add_message(driver: neo4j.Driver) -> None: + driver.execute_query(query_="MATCH (n) DETACH DELETE n;") + message_history = Neo4jMessageHistory(session_id="123", driver=driver) + message_history.add_message( + LLMMessage(role="user", content="Hello"), + ) + assert len(message_history.messages) == 1 + assert message_history.messages[0]["role"] == "user" + assert message_history.messages[0]["content"] == "Hello" + + +def test_neo4j_message_history_add_messages(driver: neo4j.Driver) -> None: + driver.execute_query(query_="MATCH (n) DETACH DELETE n;") + message_history = Neo4jMessageHistory(session_id="123", driver=driver) + message_history.add_messages( + [ + LLMMessage(role="system", content="You are a helpful assistant."), + LLMMessage(role="user", content="Hello"), + LLMMessage( + role="assistant", + content="Hello, how may I help you today?", + ), + LLMMessage(role="user", content="I'd like to buy a new car."), + LLMMessage( + role="assistant", + content="I'd be happy to help you find the perfect car.", + ), + ] + ) + assert len(message_history.messages) == 5 + assert message_history.messages[0]["role"] == "system" + assert message_history.messages[0]["content"] == "You are a helpful assistant." + assert message_history.messages[1]["role"] == "user" + assert message_history.messages[1]["content"] == "Hello" + assert message_history.messages[2]["role"] == "assistant" + assert message_history.messages[2]["content"] == "Hello, how may I help you today?" 
+ assert message_history.messages[3]["role"] == "user" + assert message_history.messages[3]["content"] == "I'd like to buy a new car." + assert message_history.messages[4]["role"] == "assistant" + assert ( + message_history.messages[4]["content"] + == "I'd be happy to help you find the perfect car." + ) + + +def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None: + driver.execute_query(query_="MATCH (n) DETACH DELETE n;") + message_history = Neo4jMessageHistory(session_id="123", driver=driver) + message_history.add_messages( + [ + LLMMessage(role="system", content="You are a helpful assistant."), + LLMMessage(role="user", content="Hello"), + ] + ) + assert len(message_history.messages) == 2 + message_history.clear() + assert len(message_history.messages) == 0 + + +def test_neo4j_message_window_size(driver: neo4j.Driver) -> None: + driver.execute_query(query_="MATCH (n) DETACH DELETE n;") + message_history = Neo4jMessageHistory(session_id="123", driver=driver, window=1) + message_history.add_messages( + [ + LLMMessage(role="system", content="You are a helpful assistant."), + LLMMessage(role="user", content="Hello"), + LLMMessage( + role="assistant", + content="Hello, how may I help you today?", + ), + LLMMessage(role="user", content="I'd like to buy a new car."), + LLMMessage( + role="assistant", + content="I'd be happy to help you find the perfect car.", + ), + ] + ) + assert len(message_history.messages) == 1 + assert ( + message_history.messages[0]["content"] + == "I'd be happy to help you find the perfect car." 
+ ) + assert message_history.messages[0]["role"] == "assistant" diff --git a/tests/unit/test_message_history.py b/tests/unit/test_message_history.py new file mode 100644 index 00000000..533235ce --- /dev/null +++ b/tests/unit/test_message_history.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock + +import pytest +from neo4j_graphrag.llm.types import LLMMessage +from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory +from pydantic import ValidationError + + +def test_in_memory_message_history_add_message() -> None: + message_history = InMemoryMessageHistory() + message_history.add_message( + LLMMessage(role="user", content="may thy knife chip and shatter") + ) + assert len(message_history.messages) == 1 + assert message_history.messages[0]["role"] == "user" + assert message_history.messages[0]["content"] == "may thy knife chip and shatter" + + +def test_in_memory_message_history_add_messages() -> None: + message_history = InMemoryMessageHistory() + message_history.add_messages( + [ + LLMMessage(role="user", content="may thy knife chip and shatter"), + LLMMessage( + role="assistant", + content="He who controls the spice controls the universe.", + ), + ] + ) + assert len(message_history.messages) == 2 + assert message_history.messages[0]["role"] == "user" + assert message_history.messages[0]["content"] == "may thy knife chip and shatter" + assert message_history.messages[1]["role"] == "assistant" + assert ( + message_history.messages[1]["content"] + == "He who controls the spice controls the universe." 
+ ) + + +def test_in_memory_message_history_clear() -> None: + message_history = InMemoryMessageHistory() + message_history.add_messages( + [ + LLMMessage(role="user", content="may thy knife chip and shatter"), + LLMMessage( + role="assistant", + content="He who controls the spice controls the universe.", + ), + ] + ) + assert len(message_history.messages) == 2 + message_history.clear() + assert len(message_history.messages) == 0 + + +def test_neo4j_message_history_invalid_session_id(driver: MagicMock) -> None: + with pytest.raises(ValidationError) as exc_info: + Neo4jMessageHistory(session_id=1.5, driver=driver, node_label="123", window=1) # type: ignore[arg-type] + assert "Input should be a valid string" in str(exc_info.value) + + +def test_neo4j_message_history_invalid_driver() -> None: + with pytest.raises(ValidationError) as exc_info: + Neo4jMessageHistory(session_id="123", driver=1.5, node_label="123", window=1) # type: ignore[arg-type] + assert "Input should be a valid dictionary or instance of Neo4jDriver" in str( + exc_info.value + ) + + +def test_neo4j_message_history_invalid_node_label(driver: MagicMock) -> None: + with pytest.raises(ValidationError) as exc_info: + Neo4jMessageHistory(session_id="123", driver=driver, node_label=1.5, window=1) # type: ignore[arg-type] + assert "Input should be a valid string" in str(exc_info.value) + + +def test_neo4j_message_history_invalid_window(driver: MagicMock) -> None: + with pytest.raises(ValidationError) as exc_info: + Neo4jMessageHistory( + session_id="123", driver=driver, node_label="123", window=-1 + ) + assert "Input should be greater than 0" in str(exc_info.value) + + +def test_neo4j_message_history_messages_setter(neo4j_driver: MagicMock) -> None: + message_history = Neo4jMessageHistory(session_id="123", driver=neo4j_driver) + with pytest.raises(NotImplementedError) as exc_info: + message_history.messages = [ + LLMMessage(role="user", content="may thy knife chip and shatter"), + ] + assert ( + 
str(exc_info.value) + == "Direct assignment to 'messages' is not allowed. Use the 'add_messages' instead." + ) From 0074b1ad969274823520ef036ba68a2c86d3d615 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Mon, 17 Feb 2025 17:26:49 +0000 Subject: [PATCH 02/18] Updated Neo4jMessageHistoryModel --- src/neo4j_graphrag/types.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index 0e705ee5..74cc35d7 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -258,3 +258,15 @@ class Neo4jMessageHistoryModel(BaseModel): driver_model: Neo4jDriverModel node_label: str = "Session" window: Optional[PositiveInt] = None + + @field_validator("session_id") + def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]: + if isinstance(v, str) and len(v) == 0: + raise ValueError("session_id cannot be empty") + return v + + @field_validator("node_label") + def validate_node_label(cls, v: str) -> str: + if len(v) == 0: + raise ValueError("node_label cannot be empty") + return v From 2bb29d206b4172289438afbded0cf438f81ac479 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Mon, 17 Feb 2025 18:00:16 +0000 Subject: [PATCH 03/18] Fixed spelling error --- src/neo4j_graphrag/llm/anthropic_llm.py | 4 ++-- src/neo4j_graphrag/llm/base.py | 4 ++-- src/neo4j_graphrag/llm/cohere_llm.py | 4 ++-- src/neo4j_graphrag/llm/mistralai_llm.py | 4 ++-- src/neo4j_graphrag/llm/ollama_llm.py | 4 ++-- src/neo4j_graphrag/llm/openai_llm.py | 4 ++-- src/neo4j_graphrag/llm/vertexai_llm.py | 4 ++-- tests/unit/llm/test_mistralai_llm.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index b4d3d8ff..e219a521 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -99,7 +99,7 @@ def invoke( Args: input (str): The text to send to the LLM. 
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. @@ -132,7 +132,7 @@ async def ainvoke( Args: input (str): The text to send to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 49f1afc3..25ebad73 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -53,7 +53,7 @@ def invoke( Args: input (str): Text sent to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. @@ -74,7 +74,7 @@ async def ainvoke( Args: input (str): Text sent to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 9621cae2..439c94fc 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -104,7 +104,7 @@ def invoke( Args: input (str): The text to send to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. @@ -132,7 +132,7 @@ async def ainvoke( Args: input (str): The text to send to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index afd1447e..7b603653 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -95,7 +95,7 @@ def invoke( Args: input (str): Text sent to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from MistralAI. @@ -131,7 +131,7 @@ async def ainvoke( Args: input (str): Text sent to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. 
- system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from MistralAI. diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 8f1e8193..572e56bd 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -86,7 +86,7 @@ def invoke( Args: input (str): The text to send to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. @@ -114,7 +114,7 @@ async def ainvoke( Args: input (str): Text sent to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 21a49a15..38ec27a0 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -95,7 +95,7 @@ def invoke( Args: input (str): Text sent to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. 
Returns: LLMResponse: The response from OpenAI. @@ -126,7 +126,7 @@ async def ainvoke( Args: input (str): Text sent to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index a465e553..2b79264e 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -114,7 +114,7 @@ def invoke( Args: input (str): The text to send to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. @@ -143,7 +143,7 @@ async def ainvoke( Args: input (str): The text to send to the LLM. message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invokation. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py index 4a3e0860..324798f2 100644 --- a/tests/unit/llm/test_mistralai_llm.py +++ b/tests/unit/llm/test_mistralai_llm.py @@ -96,7 +96,7 @@ def test_mistralai_llm_invoke_with_message_history_and_system_instruction( ] question = "What about next season?" 
- # first invokation - initial instructions + # first invocation - initial instructions res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore assert isinstance(res, LLMResponse) assert res.content == "mistral response" From 9d5914e4011fbf304fb53b89a8a11e5a58ae84e7 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Mon, 17 Feb 2025 19:11:36 +0000 Subject: [PATCH 04/18] Fixed tests --- src/neo4j_graphrag/generation/graphrag.py | 23 ++++-- src/neo4j_graphrag/message_history.py | 4 +- tests/unit/test_graphrag.py | 90 +++++++++++++++++++++-- tests/unit/test_message_history.py | 8 +- 4 files changed, 106 insertions(+), 19 deletions(-) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 6b764716..00dd213b 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -16,7 +16,7 @@ import logging import warnings -from typing import Any, Optional +from typing import Any, List, Optional, Union from pydantic import ValidationError @@ -28,6 +28,7 @@ from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.llm.types import LLMMessage +from neo4j_graphrag.message_history import InMemoryMessageHistory, MessageHistory from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import RetrieverResult @@ -84,7 +85,7 @@ def __init__( def search( self, query_text: str = "", - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, examples: str = "", retriever_config: Optional[dict[str, Any]] = None, return_context: bool | None = None, @@ -127,6 +128,8 @@ def search( ) except ValidationError as e: raise SearchValidationError(e.errors()) + if isinstance(message_history, list): + message_history = InMemoryMessageHistory(messages=message_history) query = 
self.build_query(validated_data.query_text, message_history) retriever_result: RetrieverResult = self.retriever.search( query_text=query, **validated_data.retriever_config @@ -139,7 +142,7 @@ def search( logger.debug(f"RAG: prompt={prompt}") answer = self.llm.invoke( prompt, - message_history, + message_history.messages if message_history else None, system_instruction=self.prompt_template.system_instructions, ) result: dict[str, Any] = {"answer": answer.content} @@ -148,10 +151,14 @@ def search( return RagResultModel(**result) def build_query( - self, query_text: str, message_history: Optional[list[LLMMessage]] = None + self, + query_text: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, ) -> str: summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words." if message_history: + if isinstance(message_history, list): + message_history = InMemoryMessageHistory(messages=message_history) summarization_prompt = self.chat_summary_prompt( message_history=message_history ) @@ -162,10 +169,14 @@ def build_query( return self.conversation_prompt(summary=summary, current_query=query_text) return query_text - def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str: + def chat_summary_prompt( + self, message_history: Union[List[LLMMessage], MessageHistory] + ) -> str: + if isinstance(message_history, list): + message_history = InMemoryMessageHistory(messages=message_history) message_list = [ ": ".join([f"{value}" for _, value in message.items()]) - for message in message_history + for message in message_history.messages ] history = "\n".join(message_list) return f""" diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 4c7d3650..3aeb9cae 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -57,8 +57,8 @@ def clear(self) -> None: ... 
class InMemoryMessageHistory(MessageHistory): - def __init__(self, messages: List[LLMMessage] = []) -> None: - self._messages = messages + def __init__(self, messages: Optional[List[LLMMessage]] = None) -> None: + self._messages = messages or [] @property def messages(self) -> List[LLMMessage]: diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index ffacd10f..204070c4 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -21,6 +21,8 @@ from neo4j_graphrag.generation.prompts import RagTemplate from neo4j_graphrag.generation.types import RagResultModel from neo4j_graphrag.llm import LLMResponse +from neo4j_graphrag.llm.types import LLMMessage +from neo4j_graphrag.message_history import InMemoryMessageHistory from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem @@ -114,14 +116,14 @@ def test_graphrag_happy_path_with_message_history( question """ - first_invokation_input = """ + first_invocation_input = """ Summarize the message history: user: initial question assistant: answer to initial question """ - first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words." - second_invokation = """Context: + first_invocation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words." 
+ second_invocation = """Context: item content 1 item content 2 @@ -141,11 +143,11 @@ def test_graphrag_happy_path_with_message_history( llm.invoke.assert_has_calls( [ call( - input=first_invokation_input, - system_instruction=first_invokation_system_instruction, + input=first_invocation_input, + system_instruction=first_invocation_system_instruction, ), call( - second_invokation, + second_invocation, message_history, system_instruction="Answer the user question using the provided context.", ), @@ -157,6 +159,82 @@ def test_graphrag_happy_path_with_message_history( assert res.retriever_result is None +def test_graphrag_happy_path_with_in_memory_message_history( + retriever_mock: MagicMock, llm: MagicMock +) -> None: + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + ) + retriever_mock.search.return_value = RetrieverResult( + items=[ + RetrieverResultItem(content="item content 1"), + RetrieverResultItem(content="item content 2"), + ] + ) + llm.invoke.side_effect = [ + LLMResponse(content="llm generated summary"), + LLMResponse(content="llm generated text"), + ] + message_history = InMemoryMessageHistory( + messages=[ + LLMMessage(role="user", content="initial question"), + LLMMessage(role="assistant", content="answer to initial question"), + ] + ) + res = rag.search("question", message_history) + + expected_retriever_query_text = """ +Message Summary: +llm generated summary + +Current Query: +question +""" + + first_invocation_input = """ +Summarize the message history: + +user: initial question +assistant: answer to initial question +""" + first_invocation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words." 
+ second_invocation = """Context: +item content 1 +item content 2 + +Examples: + + +Question: +question + +Answer: +""" + + retriever_mock.search.assert_called_once_with( + query_text=expected_retriever_query_text + ) + assert llm.invoke.call_count == 2 + llm.invoke.assert_has_calls( + [ + call( + input=first_invocation_input, + system_instruction=first_invocation_system_instruction, + ), + call( + second_invocation, + message_history.messages, + system_instruction="Answer the user question using the provided context.", + ), + ] + ) + + assert isinstance(res, RagResultModel) + assert res.answer == "llm generated text" + assert res.retriever_result is None + + def test_graphrag_happy_path_custom_system_instruction( retriever_mock: MagicMock, llm: MagicMock ) -> None: diff --git a/tests/unit/test_message_history.py b/tests/unit/test_message_history.py index 533235ce..42709189 100644 --- a/tests/unit/test_message_history.py +++ b/tests/unit/test_message_history.py @@ -62,9 +62,7 @@ def test_neo4j_message_history_invalid_session_id(driver: MagicMock) -> None: def test_neo4j_message_history_invalid_driver() -> None: with pytest.raises(ValidationError) as exc_info: Neo4jMessageHistory(session_id="123", driver=1.5, node_label="123", window=1) # type: ignore[arg-type] - assert "Input should be a valid dictionary or instance of Neo4jDriver" in str( - exc_info.value - ) + assert "Input should be an instance of Driver" in str(exc_info.value) def test_neo4j_message_history_invalid_node_label(driver: MagicMock) -> None: @@ -81,8 +79,8 @@ def test_neo4j_message_history_invalid_window(driver: MagicMock) -> None: assert "Input should be greater than 0" in str(exc_info.value) -def test_neo4j_message_history_messages_setter(neo4j_driver: MagicMock) -> None: - message_history = Neo4jMessageHistory(session_id="123", driver=neo4j_driver) +def test_neo4j_message_history_messages_setter(driver: MagicMock) -> None: + message_history = Neo4jMessageHistory(session_id="123", driver=driver) 
with pytest.raises(NotImplementedError) as exc_info: message_history.messages = [ LLMMessage(role="user", content="may thy knife chip and shatter"), From 05299a1ea1dc6292c7cc66ac4ca776bddcb49705 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 11:18:55 +0000 Subject: [PATCH 05/18] Added test_graphrag_happy_path_with_neo4j_message_history --- src/neo4j_graphrag/generation/graphrag.py | 5 +- tests/e2e/test_graphrag_e2e.py | 95 ++++++++++++++++++++++- 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 00dd213b..ad454681 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -103,7 +103,8 @@ def search( Args: query_text (str): The user question. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. examples (str): Examples added to the LLM prompt. retriever_config (Optional[dict]): Parameters passed to the retriever. search method; e.g.: top_k @@ -175,7 +176,7 @@ def chat_summary_prompt( if isinstance(message_history, list): message_history = InMemoryMessageHistory(messages=message_history) message_list = [ - ": ".join([f"{value}" for _, value in message.items()]) + f"{message['role']}: {message['content']}" for message in message_history.messages ] history = "\n".join(message_list) diff --git a/tests/e2e/test_graphrag_e2e.py b/tests/e2e/test_graphrag_e2e.py index d8695756..30e291e2 100644 --- a/tests/e2e/test_graphrag_e2e.py +++ b/tests/e2e/test_graphrag_e2e.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import neo4j import pytest @@ -21,8 +21,10 @@ from neo4j_graphrag.generation.graphrag import GraphRAG from neo4j_graphrag.generation.types import RagResultModel from neo4j_graphrag.llm import LLMResponse +from neo4j_graphrag.llm.types import LLMMessage +from neo4j_graphrag.message_history import Neo4jMessageHistory from neo4j_graphrag.retrievers import VectorCypherRetriever -from neo4j_graphrag.types import RetrieverResult +from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem from tests.e2e.conftest import BiologyEmbedder from tests.e2e.utils import build_data_objects, populate_neo4j @@ -79,6 +81,93 @@ def test_graphrag_happy_path( assert result.retriever_result is None +@pytest.mark.usefixtures("populate_neo4j_db") +def test_graphrag_happy_path_with_neo4j_message_history( + retriever_mock: MagicMock, + llm: MagicMock, + driver: neo4j.Driver, +) -> None: + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + ) + retriever_mock.search.return_value = RetrieverResult( + items=[ + RetrieverResultItem(content="item content 1"), + RetrieverResultItem(content="item content 2"), + ] + ) + llm.invoke.side_effect = [ + LLMResponse(content="llm generated summary"), + LLMResponse(content="llm generated text"), + ] + message_history = Neo4jMessageHistory( + driver=driver, + session_id="123", + node_label="Message", + ) + message_history.clear() + message_history.add_messages( + messages=[ + LLMMessage(role="user", content="initial question"), + LLMMessage(role="assistant", content="answer to initial question"), + ] + ) + res = rag.search( + query_text="question", + message_history=message_history, + ) + expected_retriever_query_text = """ +Message Summary: +llm generated summary + +Current Query: +question +""" + + first_invocation_input = """ +Summarize the message history: + +user: initial question +assistant: answer to initial question +""" + 
first_invocation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words." + second_invocation = """Context: +item content 1 +item content 2 + +Examples: + + +Question: +question + +Answer: +""" + retriever_mock.search.assert_called_once_with( + query_text=expected_retriever_query_text + ) + assert llm.invoke.call_count == 2 + llm.invoke.assert_has_calls( + [ + call( + input=first_invocation_input, + system_instruction=first_invocation_system_instruction, + ), + call( + second_invocation, + message_history.messages, + system_instruction="Answer the user question using the provided context.", + ), + ] + ) + + assert isinstance(res, RagResultModel) + assert res.answer == "llm generated text" + assert res.retriever_result is None + message_history.clear() + + @pytest.mark.usefixtures("populate_neo4j_db") def test_graphrag_happy_path_return_context( driver: MagicMock, llm: MagicMock, biology_embedder: BiologyEmbedder @@ -127,7 +216,7 @@ def test_graphrag_happy_path_return_context( @pytest.mark.usefixtures("populate_neo4j_db") def test_graphrag_happy_path_examples( - driver: MagicMock, llm: MagicMock, biology_embedder: BiologyEmbedder + driver: MagicMock, llm: MagicMock, biology_embedder: MagicMock ) -> None: retriever = VectorCypherRetriever( driver, From cda7962c187d496c1b053d46c9cedf76278c18ef Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 11:48:18 +0000 Subject: [PATCH 06/18] Updated LLMs --- examples/customize/llms/custom_llm.py | 7 ++++--- src/neo4j_graphrag/llm/anthropic_llm.py | 23 +++++++++++++++++------ src/neo4j_graphrag/llm/base.py | 14 +++++++++----- src/neo4j_graphrag/llm/cohere_llm.py | 21 +++++++++++++++------ src/neo4j_graphrag/llm/mistralai_llm.py | 20 ++++++++++++++------ src/neo4j_graphrag/llm/ollama_llm.py | 21 +++++++++++++++------ src/neo4j_graphrag/llm/openai_llm.py | 22 ++++++++++++++++------ src/neo4j_graphrag/llm/vertexai_llm.py | 23 +++++++++++++++++------ 8 files changed, 
107 insertions(+), 44 deletions(-) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 322d8d23..cc784cf8 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -1,9 +1,10 @@ import random import string -from typing import Any, Optional +from typing import Any, List, Optional, Union from neo4j_graphrag.llm import LLMInterface, LLMResponse from neo4j_graphrag.llm.types import LLMMessage +from neo4j_graphrag.message_history import MessageHistory class CustomLLM(LLMInterface): @@ -15,7 +16,7 @@ def __init__( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: content: str = ( @@ -26,7 +27,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: raise NotImplementedError() diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index e219a521..f5af6b1c 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Optional, cast +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast from pydantic import ValidationError @@ -26,6 +26,7 @@ MessageList, UserMessage, ) +from neo4j_graphrag.message_history import MessageHistory if TYPE_CHECKING: from anthropic.types.message_param import MessageParam @@ -76,10 +77,14 @@ def __init__( self.async_client = anthropic.AsyncAnthropic(**kwargs) def get_messages( - self, input: str, message_history: Optional[list[LLMMessage]] = None + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, ) -> Iterable[MessageParam]: messages: list[dict[str, str]] = [] if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages try: MessageList(messages=cast(list[BaseMessage], message_history)) except ValidationError as e: @@ -91,20 +96,23 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages messages = self.get_messages(input, message_history) response = self.client.messages.create( model=self.model_name, @@ -124,20 +132,23 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages messages = self.get_messages(input, message_history) response = await self.async_client.messages.create( model=self.model_name, diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 25ebad73..2a4b25eb 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -15,7 +15,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, List, Optional, Union + +from neo4j_graphrag.message_history import MessageHistory from .types import ( LLMMessage, @@ -45,14 +47,15 @@ def __init__( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. Args: input (str): Text sent to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: @@ -66,14 +69,15 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. Args: input (str): Text sent to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. 
+ message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 439c94fc..5325a799 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -14,7 +14,7 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Optional, cast +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast from pydantic import ValidationError @@ -28,6 +28,7 @@ SystemMessage, UserMessage, ) +from neo4j_graphrag.message_history import MessageHistory if TYPE_CHECKING: from cohere import ChatMessages @@ -78,13 +79,15 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> ChatMessages: messages = [] if system_instruction: messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages try: MessageList(messages=cast(list[BaseMessage], message_history)) except ValidationError as e: @@ -96,20 +99,23 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. 
+ message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages messages = self.get_messages(input, message_history, system_instruction) res = self.client.chat( messages=messages, @@ -124,20 +130,23 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages messages = self.get_messages(input, message_history, system_instruction) res = self.async_client.chat( messages=messages, diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index 7b603653..e1ed935d 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -15,7 +15,7 @@ from __future__ import annotations import os -from typing import Any, Iterable, Optional, cast +from typing import Any, Iterable, List, Optional, Union, cast from pydantic import ValidationError @@ -29,6 +29,7 @@ SystemMessage, UserMessage, ) +from neo4j_graphrag.message_history import MessageHistory try: from mistralai import Messages, Mistral @@ -68,13 +69,15 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> list[Messages]: messages = [] if system_instruction: messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages try: MessageList(messages=cast(list[BaseMessage], message_history)) except ValidationError as e: @@ -86,7 +89,7 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the Mistral chat completion model @@ -94,7 +97,7 @@ def invoke( Args: input (str): Text sent to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. 
+ message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: @@ -104,6 +107,8 @@ def invoke( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages messages = self.get_messages(input, message_history, system_instruction) response = self.client.chat.complete( model=self.model_name, @@ -122,7 +127,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the MistralAI chat @@ -130,7 +135,8 @@ async def ainvoke( Args: input (str): Text sent to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: @@ -140,6 +146,8 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages messages = self.get_messages(input, message_history, system_instruction) response = await self.client.chat.complete_async( model=self.model_name, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 572e56bd..f4e1bde6 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -14,11 +14,12 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.message_history import MessageHistory from .base import LLMInterface from .types import ( @@ -60,13 +61,15 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> Sequence[Message]: messages = [] if system_instruction: messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages try: MessageList(messages=cast(list[BaseMessage], message_history)) except ValidationError as e: @@ -78,20 +81,23 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = self.client.chat( model=self.model_name, messages=self.get_messages(input, message_history, system_instruction), @@ -105,7 +111,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat @@ -113,7 +119,8 @@ async def ainvoke( Args: input (str): Text sent to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: @@ -123,6 +130,8 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. 
""" try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = await self.async_client.chat( model=self.model_name, messages=self.get_messages(input, message_history, system_instruction), diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 38ec27a0..bcd26d60 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -15,10 +15,12 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any, Iterable, Optional, cast +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast from pydantic import ValidationError +from neo4j_graphrag.message_history import MessageHistory + from ..exceptions import LLMGenerationError from .base import LLMInterface from .types import ( @@ -68,13 +70,15 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> Iterable[ChatCompletionMessageParam]: messages = [] if system_instruction: messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages try: MessageList(messages=cast(list[BaseMessage], message_history)) except ValidationError as e: @@ -86,7 +90,7 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the OpenAI chat completion model @@ -94,7 +98,8 @@ def invoke( Args: input (str): Text sent to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. 
+ message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: @@ -104,6 +109,8 @@ def invoke( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = self.client.chat.completions.create( messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, @@ -117,7 +124,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat @@ -125,7 +132,8 @@ async def ainvoke( Args: input (str): Text sent to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: @@ -135,6 +143,8 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = await self.async_client.chat.completions.create( messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 2b79264e..917d7de1 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,13 +13,14 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Optional, cast +from typing import Any, List, Optional, Union, cast from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.llm.types import BaseMessage, LLMMessage, LLMResponse, MessageList +from neo4j_graphrag.message_history import MessageHistory try: from vertexai.generative_models import ( @@ -77,10 +78,14 @@ def __init__( self.options = kwargs def get_messages( - self, input: str, message_history: Optional[list[LLMMessage]] = None + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, ) -> list[Content]: messages = [] if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages try: MessageList(messages=cast(list[BaseMessage], message_history)) except ValidationError as e: @@ -106,14 +111,15 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. 
Returns: @@ -126,6 +132,8 @@ def invoke( **self.options, ) try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages messages = self.get_messages(input, message_history) response = self.model.generate_content(messages, **self.model_params) return LLMResponse(content=response.text) @@ -135,20 +143,23 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[LLMMessage]] = None, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages system_message = ( [system_instruction] if system_instruction is not None else [] ) From a2cd51831d5d9f9df6da6a2deea465b7990521f3 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 11:50:28 +0000 Subject: [PATCH 07/18] Added missing copyright headers --- src/neo4j_graphrag/message_history.py | 14 ++++++++++++++ tests/e2e/test_message_history_e2e.py | 14 ++++++++++++++ tests/unit/test_message_history.py | 14 ++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 3aeb9cae..ccdffba7 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -1,3 +1,17 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC, abstractmethod from typing import List, Optional, Union diff --git a/tests/e2e/test_message_history_e2e.py b/tests/e2e/test_message_history_e2e.py index f6413bda..274c8fb0 100644 --- a/tests/e2e/test_message_history_e2e.py +++ b/tests/e2e/test_message_history_e2e.py @@ -1,3 +1,17 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import neo4j from neo4j_graphrag.llm.types import LLMMessage from neo4j_graphrag.message_history import Neo4jMessageHistory diff --git a/tests/unit/test_message_history.py b/tests/unit/test_message_history.py index 42709189..e67700e0 100644 --- a/tests/unit/test_message_history.py +++ b/tests/unit/test_message_history.py @@ -1,3 +1,17 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from unittest.mock import MagicMock import pytest From f248dbe674cfddbf8edf03d115c655f649fe1fcd Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 12:27:26 +0000 Subject: [PATCH 08/18] Refactored graphrag --- src/neo4j_graphrag/generation/graphrag.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index ad454681..53261fcc 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -28,7 +28,7 @@ from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.llm.types import LLMMessage -from neo4j_graphrag.message_history import InMemoryMessageHistory, MessageHistory +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import RetrieverResult @@ -129,8 +129,8 @@ def search( ) except ValidationError as e: raise SearchValidationError(e.errors()) - if isinstance(message_history, list): - message_history = InMemoryMessageHistory(messages=message_history) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages query = self.build_query(validated_data.query_text, message_history) retriever_result: RetrieverResult = self.retriever.search( query_text=query, **validated_data.retriever_config @@ -143,7 +143,7 @@ def search( logger.debug(f"RAG: prompt={prompt}") answer = self.llm.invoke( prompt, - message_history.messages if message_history else None, + message_history, system_instruction=self.prompt_template.system_instructions, ) result: dict[str, Any] = {"answer": answer.content} @@ -158,8 +158,8 @@ def build_query( ) -> str: summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words." 
if message_history: - if isinstance(message_history, list): - message_history = InMemoryMessageHistory(messages=message_history) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages summarization_prompt = self.chat_summary_prompt( message_history=message_history ) @@ -173,11 +173,10 @@ def build_query( def chat_summary_prompt( self, message_history: Union[List[LLMMessage], MessageHistory] ) -> str: - if isinstance(message_history, list): - message_history = InMemoryMessageHistory(messages=message_history) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages message_list = [ - f"{message['role']}: {message['content']}" - for message in message_history.messages + f"{message['role']}: {message['content']}" for message in message_history ] history = "\n".join(message_list) return f""" From 72c7070c84ad001a2b090fb769b30139d55c2a8b Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 15:17:33 +0000 Subject: [PATCH 09/18] Added docstrings to message history classes --- src/neo4j_graphrag/message_history.py | 53 +++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index ccdffba7..361e1f0a 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -55,6 +55,8 @@ class MessageHistory(ABC): + """Abstract base class for message history storage.""" + @property @abstractmethod def messages(self) -> List[LLMMessage]: ... @@ -71,6 +73,24 @@ def clear(self) -> None: ... class InMemoryMessageHistory(MessageHistory): + """Message history stored in memory + + Example: + + .. 
code-block:: python + + from neo4j_graphrag.llm.types import LLMMessage + from neo4j_graphrag.message_history import InMemoryMessageHistory + + history = InMemoryMessageHistory() + + message = LLMMessage(role="user", content="Hello!") + history.add_message(message) + + Args: + messages (Optional[List[LLMMessage]]): List of messages to initialize the history with. Defaults to None. + """ + def __init__(self, messages: Optional[List[LLMMessage]] = None) -> None: self._messages = messages or [] @@ -89,6 +109,33 @@ def clear(self) -> None: class Neo4jMessageHistory(MessageHistory): + """Message history stored in a Neo4j database + + Example: + + .. code-block:: python + + import neo4j + from neo4j_graphrag.llm.types import LLMMessage + from neo4j_graphrag.message_history import Neo4jMessageHistory + + driver = neo4j.GraphDatabase.driver(URI, auth=AUTH) + + history = Neo4jMessageHistory( + session_id="123", driver=driver, node_label="Message", window=10 + ) + + message = LLMMessage(role="user", content="Hello!") + history.add_message(message) + + Args: + session_id (Union[str, int]): Unique identifier for the chat session. + driver (neo4j.Driver): Neo4j driver instance. + node_label (str, optional): Label used for session nodes in Neo4j. Defaults to "Session". + window (Optional[PositiveInt], optional): Number of previous messages to return when retrieving messages. + + """ + def __init__( self, session_id: Union[str, int], @@ -139,6 +186,11 @@ def messages(self, messages: List[LLMMessage]) -> None: ) def add_message(self, message: LLMMessage) -> None: + """Add a message to the message history. + + Args: + message (LLMMessage): The message to add. 
+ """ self._driver.execute_query( query_=ADD_MESSAGE_QUERY.format(node_label=self._node_label), parameters_={ @@ -149,6 +201,7 @@ def add_message(self, message: LLMMessage) -> None: ) def clear(self) -> None: + """Clear the message history.""" self._driver.execute_query( query_=CLEAR_SESSION_QUERY.format(node_label=self._node_label), parameters_={"session_id": self._session_id}, From 0f3c6de5e53fe8217423db98ad4fee7b9d418b6d Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 15:27:58 +0000 Subject: [PATCH 10/18] Added message history examples --- examples/README.md | 3 +- .../llms/llm_with_neo4j_message_history.py | 61 +++++++++++++ .../graphrag_with_neo4j_message_history.py | 89 +++++++++++++++++++ 3 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 examples/customize/llms/llm_with_neo4j_message_history.py create mode 100644 examples/question_answering/graphrag_with_neo4j_message_history.py diff --git a/examples/README.md b/examples/README.md index 7a53aaec..3dd6ee65 100644 --- a/examples/README.md +++ b/examples/README.md @@ -52,7 +52,7 @@ are listed in [the last section of this file](#customize). - [End to end GraphRAG](./answer/graphrag.py) - [GraphRAG with message history](./question_answering/graphrag_with_message_history.py) - +- [GraphRAG with Neo4j message history](./question_answering/graphrag_with_neo4j_message_history.py) ## Customize @@ -75,6 +75,7 @@ are listed in [the last section of this file](#customize). 
 - [Custom LLM](./customize/llms/custom_llm.py)
 
 - [Message history](./customize/llms/llm_with_message_history.py)
+- [Message history with Neo4j](./customize/llms/llm_with_neo4j_message_history.py)
 - [System Instruction](./customize/llms/llm_with_system_instructions.py)
 
 
diff --git a/examples/customize/llms/llm_with_neo4j_message_history.py b/examples/customize/llms/llm_with_neo4j_message_history.py
new file mode 100644
index 00000000..5782ff2c
--- /dev/null
+++ b/examples/customize/llms/llm_with_neo4j_message_history.py
@@ -0,0 +1,61 @@
+"""This example illustrates the message_history feature
+of the LLMInterface by mocking a conversation between a user
+and an LLM about Tom Hanks.
+
+Neo4j is used as the database for storing the message history.
+
+OpenAILLM can be replaced by any supported LLM from this package.
+"""
+
+import neo4j
+from neo4j_graphrag.llm import LLMResponse, OpenAILLM
+from neo4j_graphrag.message_history import Neo4jMessageHistory
+
+# Define database credentials
+URI = "neo4j+s://demo.neo4jlabs.com"
+AUTH = ("recommendations", "recommendations")
+DATABASE = "recommendations"
+INDEX = "moviePlotsEmbedding"
+
+# set api key here or in the OPENAI_API_KEY env var
+api_key = None
+
+llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
+
+questions = [
+    "What are some movies Tom Hanks starred in?",
+    "Is he also a director?",
+    "Wow, that's impressive. 
And what about his personal life, does he have children?", +] + +driver = neo4j.GraphDatabase.driver( + URI, + auth=AUTH, + database=DATABASE, +) + +history = Neo4jMessageHistory( + session_id="123", driver=driver, node_label="Message", window=10 +) + +for question in questions: + res: LLMResponse = llm.invoke( + question, + message_history=history, + ) + history.add_message( + { + "role": "user", + "content": question, + } + ) + history.add_message( + { + "role": "assistant", + "content": res.content, + } + ) + + print("#" * 50, question) + print(res.content) + print("#" * 50) diff --git a/examples/question_answering/graphrag_with_neo4j_message_history.py b/examples/question_answering/graphrag_with_neo4j_message_history.py new file mode 100644 index 00000000..383a2a25 --- /dev/null +++ b/examples/question_answering/graphrag_with_neo4j_message_history.py @@ -0,0 +1,89 @@ +"""End to end example of building a RAG pipeline backed by a Neo4j database, +simulating a chat with message history which is also stored in Neo4j. + +Requires OPENAI_API_KEY to be in the env var. 
+""" + +import neo4j +from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings +from neo4j_graphrag.generation import GraphRAG +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.message_history import Neo4jMessageHistory +from neo4j_graphrag.retrievers import VectorCypherRetriever + +# Define database credentials +URI = "neo4j+s://demo.neo4jlabs.com" +AUTH = ("recommendations", "recommendations") +DATABASE = "recommendations" +INDEX = "moviePlotsEmbedding" + + +driver = neo4j.GraphDatabase.driver( + URI, + auth=AUTH, +) + +embedder = OpenAIEmbeddings() + +retriever = VectorCypherRetriever( + driver, + index_name=INDEX, + retrieval_query=""" + WITH node as movie, score + CALL(movie) { + MATCH (movie)<-[:ACTED_IN]-(p:Person) + RETURN collect(p.name) as actors + } + CALL(movie) { + MATCH (movie)<-[:DIRECTED]-(p:Person) + RETURN collect(p.name) as directors + } + RETURN movie.title as title, movie.plot as plot, movie.year as year, actors, directors + """, + embedder=embedder, + neo4j_database=DATABASE, +) + +llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0}) + +rag = GraphRAG( + retriever=retriever, + llm=llm, +) + +history = Neo4jMessageHistory( + session_id="123", driver=driver, node_label="Message", window=10 +) + +questions = [ + "Who starred in the Apollo 13 movies?", + "Who was its director?", + "In which year was this movie released?", +] + +for question in questions: + result = rag.search( + question, + return_context=False, + message_history=history, + ) + + answer = result.answer + print("#" * 50, question) + print(answer) + print("#" * 50) + + history.add_message( + { + "role": "user", + "content": question, + } + ) + history.add_message( + { + "role": "assistant", + "content": answer, + } + ) + +driver.close() From c3e671c841526c9ecc23b42f373aa1518fbb1b8e Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 15:32:53 +0000 Subject: [PATCH 11/18] Updated docs --- docs/source/api.rst | 9 +++++++++ 
docs/source/index.rst | 1 + docs/source/user_guide_rag.rst | 1 + 3 files changed, 11 insertions(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index 9fc34438..b9498509 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -403,6 +403,15 @@ Database Interaction .. autofunction:: neo4j_graphrag.schema.format_schema +*************** +Message History +*************** + +.. autoclass:: neo4j_graphrag.message_history.InMemoryMessageHistory + +.. autoclass:: neo4j_graphrag.message_history.Neo4jMessageHistory + + ****** Errors ****** diff --git a/docs/source/index.rst b/docs/source/index.rst index d5a330e4..ae8c25d6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -148,6 +148,7 @@ Note that the below example is not the only way you can upsert data into your Ne .. code:: python + from neo4j import GraphDatabase from neo4j_graphrag.indexes import upsert_vectors from neo4j_graphrag.types import EntityType diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 9b7d0cd1..f8c2c421 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -917,6 +917,7 @@ Populate a Vector Index ========================== .. code:: python + from random import random from neo4j import GraphDatabase From 15bfd3f2125f87ca446d5e69ecec0d62329b6d54 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Tue, 18 Feb 2025 15:41:32 +0000 Subject: [PATCH 12/18] Updated CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0006747f..609eb7ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ - Support for effective_search_ratio parameter in vector and hybrid searches. - Introduced upsert_vectors utility function for batch upserting embeddings to vector indexes. - Introduced `extract_cypher` function to enhance Cypher query extraction and formatting in `Text2CypherRetriever`. 
+- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories. +- Added examples and documentation for using message history with Neo4j and in-memory storage. +- Updated LLM and GraphRAG classes to support new message history classes. ### Changed From c9e3c2717302ab67149bdc43f235aea38bee3870 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 19 Feb 2025 10:46:33 +0000 Subject: [PATCH 13/18] Removed Neo4jMessageHistory __del__ method --- src/neo4j_graphrag/message_history.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 361e1f0a..9deed1a0 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -89,6 +89,7 @@ class InMemoryMessageHistory(MessageHistory): Args: messages (Optional[List[LLMMessage]]): List of messages to initialize the history with. Defaults to None. + """ def __init__(self, messages: Optional[List[LLMMessage]] = None) -> None: @@ -206,7 +207,3 @@ def clear(self) -> None: query_=CLEAR_SESSION_QUERY.format(node_label=self._node_label), parameters_={"session_id": self._session_id}, ) - - def __del__(self) -> None: - if self._driver: - self._driver.close() From 5a873dc35bc992e927a2fe95e60df4c9d3227884 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 19 Feb 2025 11:06:46 +0000 Subject: [PATCH 14/18] Makes the build_query and chat_summary_prompt methods in the GraphRAG class private --- src/neo4j_graphrag/generation/graphrag.py | 16 +++++----------- tests/unit/test_graphrag.py | 2 +- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 53261fcc..81de3ccb 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -131,7 +131,7 @@ def search( raise SearchValidationError(e.errors()) if isinstance(message_history, 
MessageHistory): message_history = message_history.messages - query = self.build_query(validated_data.query_text, message_history) + query = self._build_query(validated_data.query_text, message_history) retriever_result: RetrieverResult = self.retriever.search( query_text=query, **validated_data.retriever_config ) @@ -151,16 +151,14 @@ def search( result["retriever_result"] = retriever_result return RagResultModel(**result) - def build_query( + def _build_query( self, query_text: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + message_history: Optional[List[LLMMessage]] = None, ) -> str: summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words." if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - summarization_prompt = self.chat_summary_prompt( + summarization_prompt = self._chat_summary_prompt( message_history=message_history ) summary = self.llm.invoke( @@ -170,11 +168,7 @@ def build_query( return self.conversation_prompt(summary=summary, current_query=query_text) return query_text - def chat_summary_prompt( - self, message_history: Union[List[LLMMessage], MessageHistory] - ) -> str: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages + def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str: message_list = [ f"{message['role']}: {message['content']}" for message in message_history ] diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index 204070c4..101980e9 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -294,7 +294,7 @@ def test_chat_summary_template(retriever_mock: MagicMock, llm: MagicMock) -> Non retriever=retriever_mock, llm=llm, ) - prompt = rag.chat_summary_prompt(message_history=message_history) # type: ignore + prompt = rag._chat_summary_prompt(message_history=message_history) # type: ignore 
assert ( prompt == """ From cfcf32433be89267942b60dd9e61e5daac9219bc Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 19 Feb 2025 11:23:55 +0000 Subject: [PATCH 15/18] Added a threading lock to InMemoryMessageHistory --- src/neo4j_graphrag/message_history.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 9deed1a0..3900cfee 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading from abc import ABC, abstractmethod from typing import List, Optional, Union @@ -93,20 +94,25 @@ class InMemoryMessageHistory(MessageHistory): """ def __init__(self, messages: Optional[List[LLMMessage]] = None) -> None: + self._lock = threading.Lock() self._messages = messages or [] @property def messages(self) -> List[LLMMessage]: - return self._messages + with self._lock: + return self._messages.copy() def add_message(self, message: LLMMessage) -> None: - self._messages.append(message) + with self._lock: + self._messages.append(message) def add_messages(self, messages: List[LLMMessage]) -> None: - self._messages.extend(messages) + with self._lock: + self._messages.extend(messages) def clear(self) -> None: - self._messages = [] + with self._lock: + self._messages = [] class Neo4jMessageHistory(MessageHistory): From 6e8e10d54c0cd15f6f45c081e48b5a994033e829 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 19 Feb 2025 15:47:01 +0000 Subject: [PATCH 16/18] Removed node_label parameter from Neo4jMessageHistory --- .../llms/llm_with_neo4j_message_history.py | 4 +--- .../graphrag_with_neo4j_message_history.py | 4 +--- src/neo4j_graphrag/message_history.py | 15 +++++---------- src/neo4j_graphrag/types.py | 7 ------- 
tests/e2e/test_graphrag_e2e.py | 1 - tests/unit/test_message_history.py | 14 +++----------- 6 files changed, 10 insertions(+), 35 deletions(-) diff --git a/examples/customize/llms/llm_with_neo4j_message_history.py b/examples/customize/llms/llm_with_neo4j_message_history.py index 5782ff2c..9f973895 100644 --- a/examples/customize/llms/llm_with_neo4j_message_history.py +++ b/examples/customize/llms/llm_with_neo4j_message_history.py @@ -34,9 +34,7 @@ database=DATABASE, ) -history = Neo4jMessageHistory( - session_id="123", driver=driver, node_label="Message", window=10 -) +history = Neo4jMessageHistory(session_id="123", driver=driver, window=10) for question in questions: res: LLMResponse = llm.invoke( diff --git a/examples/question_answering/graphrag_with_neo4j_message_history.py b/examples/question_answering/graphrag_with_neo4j_message_history.py index 383a2a25..6b360c94 100644 --- a/examples/question_answering/graphrag_with_neo4j_message_history.py +++ b/examples/question_answering/graphrag_with_neo4j_message_history.py @@ -51,9 +51,7 @@ llm=llm, ) -history = Neo4jMessageHistory( - session_id="123", driver=driver, node_label="Message", window=10 -) +history = Neo4jMessageHistory(session_id="123", driver=driver, window=10) questions = [ "Who starred in the Apollo 13 movies?", diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 3900cfee..9ea0207f 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -129,7 +129,7 @@ class Neo4jMessageHistory(MessageHistory): driver = neo4j.GraphDatabase.driver(URI, auth=AUTH) history = Neo4jMessageHistory( - session_id="123", driver=driver, node_label="Message", window=10 + session_id="123", driver=driver, window=10 ) message = LLMMessage(role="user", content="Hello!") @@ -147,33 +147,28 @@ def __init__( self, session_id: Union[str, int], driver: neo4j.Driver, - node_label: str = "Session", window: Optional[PositiveInt] = None, ) -> None: 
validated_data = Neo4jMessageHistoryModel( session_id=session_id, driver_model=Neo4jDriverModel(driver=driver), - node_label=node_label, window=window, ) self._driver = validated_data.driver_model.driver self._session_id = validated_data.session_id - self._node_label = validated_data.node_label self._window = ( "" if validated_data.window is None else validated_data.window - 1 ) # Create session node self._driver.execute_query( - query_=CREATE_SESSION_NODE_QUERY.format(node_label=self._node_label), + query_=CREATE_SESSION_NODE_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, ) @property def messages(self) -> List[LLMMessage]: result = self._driver.execute_query( - query_=GET_MESSAGES_QUERY.format( - node_label=self._node_label, window=self._window - ), + query_=GET_MESSAGES_QUERY.format(node_label="Session", window=self._window), parameters_={"session_id": self._session_id}, ) messages = [ @@ -199,7 +194,7 @@ def add_message(self, message: LLMMessage) -> None: message (LLMMessage): The message to add. 
""" self._driver.execute_query( - query_=ADD_MESSAGE_QUERY.format(node_label=self._node_label), + query_=ADD_MESSAGE_QUERY.format(node_label="Session"), parameters_={ "role": message["role"], "content": message["content"], @@ -210,6 +205,6 @@ def add_message(self, message: LLMMessage) -> None: def clear(self) -> None: """Clear the message history.""" self._driver.execute_query( - query_=CLEAR_SESSION_QUERY.format(node_label=self._node_label), + query_=CLEAR_SESSION_QUERY.format(node_label="Session"), parameters_={"session_id": self._session_id}, ) diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index 74cc35d7..938f0a8d 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -256,7 +256,6 @@ class Text2CypherRetrieverModel(BaseModel): class Neo4jMessageHistoryModel(BaseModel): session_id: Union[str, int] driver_model: Neo4jDriverModel - node_label: str = "Session" window: Optional[PositiveInt] = None @field_validator("session_id") @@ -264,9 +263,3 @@ def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]: if isinstance(v, str) and len(v) == 0: raise ValueError("session_id cannot be empty") return v - - @field_validator("node_label") - def validate_node_label(cls, v: str) -> str: - if len(v) == 0: - raise ValueError("node_label cannot be empty") - return v diff --git a/tests/e2e/test_graphrag_e2e.py b/tests/e2e/test_graphrag_e2e.py index 30e291e2..0f0094d3 100644 --- a/tests/e2e/test_graphrag_e2e.py +++ b/tests/e2e/test_graphrag_e2e.py @@ -104,7 +104,6 @@ def test_graphrag_happy_path_with_neo4j_message_history( message_history = Neo4jMessageHistory( driver=driver, session_id="123", - node_label="Message", ) message_history.clear() message_history.add_messages( diff --git a/tests/unit/test_message_history.py b/tests/unit/test_message_history.py index e67700e0..1c27bbb5 100644 --- a/tests/unit/test_message_history.py +++ b/tests/unit/test_message_history.py @@ -69,27 +69,19 @@ def 
test_in_memory_message_history_clear() -> None: def test_neo4j_message_history_invalid_session_id(driver: MagicMock) -> None: with pytest.raises(ValidationError) as exc_info: - Neo4jMessageHistory(session_id=1.5, driver=driver, node_label="123", window=1) # type: ignore[arg-type] + Neo4jMessageHistory(session_id=1.5, driver=driver, window=1) # type: ignore[arg-type] assert "Input should be a valid string" in str(exc_info.value) def test_neo4j_message_history_invalid_driver() -> None: with pytest.raises(ValidationError) as exc_info: - Neo4jMessageHistory(session_id="123", driver=1.5, node_label="123", window=1) # type: ignore[arg-type] + Neo4jMessageHistory(session_id="123", driver=1.5, window=1) # type: ignore[arg-type] assert "Input should be an instance of Driver" in str(exc_info.value) -def test_neo4j_message_history_invalid_node_label(driver: MagicMock) -> None: - with pytest.raises(ValidationError) as exc_info: - Neo4jMessageHistory(session_id="123", driver=driver, node_label=1.5, window=1) # type: ignore[arg-type] - assert "Input should be a valid string" in str(exc_info.value) - - def test_neo4j_message_history_invalid_window(driver: MagicMock) -> None: with pytest.raises(ValidationError) as exc_info: - Neo4jMessageHistory( - session_id="123", driver=driver, node_label="123", window=-1 - ) + Neo4jMessageHistory(session_id="123", driver=driver, window=-1) assert "Input should be greater than 0" in str(exc_info.value) From f56ba20c355309ddb2024a686a8399b4989b2206 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 10:02:24 +0000 Subject: [PATCH 17/18] Updated CLEAR_SESSION_QUERY --- src/neo4j_graphrag/message_history.py | 10 ++++++---- tests/e2e/test_graphrag_e2e.py | 1 - tests/e2e/test_message_history_e2e.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 9ea0207f..47fa939c 100644 --- a/src/neo4j_graphrag/message_history.py +++ 
b/src/neo4j_graphrag/message_history.py @@ -30,10 +30,12 @@ CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})" CLEAR_SESSION_QUERY = ( - "MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) " - "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() " - "WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 " - "UNWIND nodes(p) as node DETACH DELETE node;" + "MATCH (s:`{node_label}`) " + "WHERE s.id = $session_id " + "OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) " + "WITH CASE WHEN p IS NULL THEN [s] ELSE nodes(p) + [s] END AS nodes " + "UNWIND nodes AS node " + "DETACH DELETE node;" ) GET_MESSAGES_QUERY = ( diff --git a/tests/e2e/test_graphrag_e2e.py b/tests/e2e/test_graphrag_e2e.py index 0f0094d3..5b167018 100644 --- a/tests/e2e/test_graphrag_e2e.py +++ b/tests/e2e/test_graphrag_e2e.py @@ -105,7 +105,6 @@ def test_graphrag_happy_path_with_neo4j_message_history( driver=driver, session_id="123", ) - message_history.clear() message_history.add_messages( messages=[ LLMMessage(role="user", content="initial question"), diff --git a/tests/e2e/test_message_history_e2e.py b/tests/e2e/test_message_history_e2e.py index 274c8fb0..f1a174f6 100644 --- a/tests/e2e/test_message_history_e2e.py +++ b/tests/e2e/test_message_history_e2e.py @@ -76,6 +76,16 @@ def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None: assert len(message_history.messages) == 0 +def test_neo4j_message_history_clear_no_messages(driver: neo4j.Driver) -> None: + driver.execute_query(query_="MATCH (n) DETACH DELETE n;") + message_history = Neo4jMessageHistory(session_id="123", driver=driver) + message_history.clear() + results = driver.execute_query( + query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s" + ) + assert results.records == [] + + def test_neo4j_message_window_size(driver: neo4j.Driver) -> None: driver.execute_query(query_="MATCH (n) DETACH DELETE n;") message_history = Neo4jMessageHistory(session_id="123", 
driver=driver, window=1) From e84230ad0a9557a50c0daa27759560a530280594 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 16:25:11 +0000 Subject: [PATCH 18/18] Fixed CLEAR_SESSION_QUERY --- src/neo4j_graphrag/message_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 47fa939c..44e6383e 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -33,7 +33,7 @@ "MATCH (s:`{node_label}`) " "WHERE s.id = $session_id " "OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) " - "WITH CASE WHEN p IS NULL THEN [s] ELSE nodes(p) + [s] END AS nodes " + "WITH CASE WHEN p IS NULL THEN [s] ELSE nodes(p) END AS nodes " "UNWIND nodes AS node " "DETACH DELETE node;" )