From b58a3ebfec936e66ea941a78ce534d976aa01bb1 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Thu, 9 Jan 2025 22:02:52 -0800 Subject: [PATCH 01/11] Improve dedupe prompt --- packages/memory_module/core/memory_core.py | 32 ++++++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index dc018124..4530697b 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -208,15 +208,35 @@ async def _extract_memory_processing_decision( old_memory_content = "\n".join( [f"{memory.content}" for memory in old_memories] ) - system_message = f"""You are a semantic memory management agent. Your goal is to determine whether this new memory is duplicated with existing old memories. + system_message = f"""You are a semantic memory management agent. Your task is to decide whether the new memory should be added to the memory system or ignored as a duplicate. + Considerations: -- Time-based order: Each old memory has a creation time. Please take creation time into consideration. -- Repeated behavior: If the new memory indicates a repeated idea over a period of time, it should be added to reflect the pattern. -Return value: -- Add: add new memory while keep old memories -- Ignore: indicates that this memory is similar to an older memory and should be ignored +1. Context Overlap: +If the new memory conveys information that is substantially covered by an existing memory, it should be ignored. +If the new memory adds unique or specific information not present in any old memory, it should be added. +2. Granularity of Detail: +Broader or more general memories should not replace specific ones. However, a specific detail can replace a general statement if it conveys the same underlying idea. +For example: +Old memory: “The user enjoys hiking in national parks.” +New memory: “The user enjoys hiking in Yellowstone National Park.” +Result: Ignore (The older memory already encompasses the specific case). +3. Repeated Patterns: +If the new memory reinforces a pattern of behavior over time (e.g., multiple mentions of a recurring habit, preference, or routine), it should be added to reflect this trend. +4. Temporal Relevance: +If the new memory reflects a significant change or update to the old memory, it should be added. +For example: +Old memory: “The user is planning a trip to Japan.” +New memory: “The user has canceled their trip to Japan.” +Result: Add (The new memory reflects a change). + +Process: + 1. Compare the specificity, unique details, and time relevance of the new memory against old memories. + 2. Decide whether to add or ignore based on the considerations above. + 3. Provide a clear and concise justification for your decision. + Here are the old memories: {old_memory_content} + Here is the new memory: {new_memory} created at {str(datetime.datetime.now())} """ # noqa: E501 From 0d57654ba9606a37a943a4e0563753be2ea8dc61 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Thu, 9 Jan 2025 23:25:27 -0800 Subject: [PATCH 02/11] Add tests --- packages/memory_module/core/memory_core.py | 8 +- scripts/evaluate_memory_decisions.py | 214 +++++++++++++++++++++ tests/memory_module/test_memory_module.py | 8 +- 3 files changed, 224 insertions(+), 6 deletions(-) create mode 100644 scripts/evaluate_memory_decisions.py diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index 4530697b..9c12912a 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -134,7 +134,7 @@ async def process_semantic_messages( if extraction.action == "add" and extraction.facts: for fact in extraction.facts: decision = await self._get_add_memory_processing_decision(fact.text, author_id) - if decision == "ignore": + if decision.decision == "ignore": logger.info(f"Decision to ignore fact {fact.text}") continue metadata = await self._extract_metadata_from_fact(fact.text) @@ -194,10 +194,12 @@ async def remove_messages(self, message_ids: List[str]) -> None: await self.storage.remove_messages(message_ids) logger.info("messages {} are removed".format(",".join(message_ids))) - async def _get_add_memory_processing_decision(self, new_memory_fact: str, user_id: Optional[str]) -> str: + async def _get_add_memory_processing_decision( + self, new_memory_fact: str, user_id: Optional[str] + ) -> ProcessSemanticMemoryDecision: similar_memories = await self.retrieve_memories(new_memory_fact, user_id, None) decision = await self._extract_memory_processing_decision(new_memory_fact, similar_memories, user_id) - return decision.decision + return decision async def _extract_memory_processing_decision( self, new_memory: str, old_memories: List[Memory], user_id: Optional[str] diff --git a/scripts/evaluate_memory_decisions.py b/scripts/evaluate_memory_decisions.py new file mode 100644 index 00000000..61589f3a --- /dev/null +++ b/scripts/evaluate_memory_decisions.py @@ -0,0 +1,214 @@ +import asyncio +import sys +from datetime import datetime, timedelta +from pathlib import Path +from typing import cast +from uuid import uuid4 + +from memory_module import MemoryModule, MemoryModuleConfig, UserMessageInput +from memory_module.core.message_buffer import MessageBuffer +from memory_module.core.message_queue import MessageQueue +from memory_module.core.scheduler import Scheduler +from tqdm import tqdm + +from tests.memory_module.utils import build_llm_config + +# Test cases from before +TEST_CASES = [ + { + "title": "General vs. Specific Detail", + "old_messages": ["I love outdoor activities.", "I often visit national parks."], + "incoming_message": "I enjoy hiking in Yellowstone National Park.", + "expected_decision": "ignore", + "reason": "The old messages already cover the new message’s context.", + }, + { + "title": "Specific Detail vs. General", + "old_messages": ["I really enjoy hiking in Yellowstone National Park.", "I like exploring scenic trails."], + "incoming_message": "I enjoy hiking in national parks.", + "expected_decision": "ignore", + "reason": "The new message is broader and redundant to the old messages.", + }, + { + "title": "Repeated Behavior Over Time", + "old_messages": ["I had coffee at 8 AM yesterday.", "I had coffee at 8 AM this morning."], + "incoming_message": "I had coffee at 8 AM again today.", + "expected_decision": "add", + "reason": "This reinforces a recurring pattern of behavior over time.", + }, + { + "title": "Updated Temporal Context", + "old_messages": ["I’m planning a trip to Japan.", "I’ve been looking at flights to Japan."], + "incoming_message": "I just canceled my trip to Japan.", + "expected_decision": "add", + "reason": "The new message reflects a significant update to the old messages.", + }, + { + "title": "Irrelevant or Unnecessary Update", + "old_messages": ["I prefer tea over coffee.", "I usually drink tea every day."], + "incoming_message": "I like tea.", + "expected_decision": "ignore", + "reason": "The new message does not add any unique or relevant information.", + }, + { + "title": "Redundant Memory with Different Wording", + "old_messages": ["I have an iPhone 12.", "I bought an iPhone 12 back in 2022."], + "incoming_message": "I own an iPhone 12.", + "expected_decision": "ignore", + "reason": "The new message is a rephrased duplicate of the old messages.", + }, + { + "title": "Additional Specific Information", + "old_messages": ["I like playing video games.", "I often play games on my console."], + "incoming_message": "I love playing RPG video games like Final Fantasy.", + "expected_decision": "add", + "reason": "The new message adds specific details about the type of games.", + }, + { + "title": "Contradictory Information", + "old_messages": ["I like cats.", "I have a cat named Whiskers."], + "incoming_message": "Actually, I don’t like cats.", + "expected_decision": "add", + "reason": "The new message reflects a contradiction or change in preference.", + }, + { + "title": "New Memory Completely Unrelated", + "old_messages": ["I love reading mystery novels.", "I’m a big fan of Agatha Christie’s books."], + "incoming_message": "I really enjoy playing soccer.", + "expected_decision": "add", + "reason": "The new message introduces entirely new information.", + }, + { + "title": "Multiple Old Messages with Partial Overlap", + "old_messages": ["I have a car.", "My car is a Toyota Camry."], + "incoming_message": "I own a blue Toyota Camry.", + "expected_decision": "add", + "reason": "The new message adds a specific detail (color) not covered by the old messages.", + }, +] + + +async def evaluate_decision(memory_module, test_case): + """Evaluate a single decision test case.""" + conversation_id = str(uuid4()) + + # Add old messages + for message_content in test_case["old_messages"]: + message = UserMessageInput( + id=str(uuid4()), + content=message_content, + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(days=1), + ) + await memory_module.add_message(message) + + await memory_module.message_queue.message_buffer.scheduler.flush() + + # Create incoming message + new_message = [ + UserMessageInput( + id=str(uuid4()), + content=test_case["incoming_message"], + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now(), + ) + ] + + # Get the decision + extraction = await memory_module.memory_core._extract_semantic_fact_from_messages(new_message) + if not (extraction.action == "add" and extraction.facts): + return { + "success": False, + "error": "Failed to extract semantic facts", + "test_case": test_case, + "expected": test_case["expected_decision"], + "got": "failed_extraction", + "reason": "Failed to extract semantic facts", + } + + for fact in extraction.facts: + decision = await memory_module.memory_core._get_add_memory_processing_decision(fact.text, "user-123") + return { + "success": decision.decision == test_case["expected_decision"], + "expected": test_case["expected_decision"], + "got": decision.decision, + "reason": decision.reason_for_decision, + "test_case": test_case, + } + + +async def main(): + # Initialize config and memory module + llm_config = build_llm_config() + if not llm_config.api_key: + print("Error: OpenAI API key not provided") + sys.exit(1) + + db_path = Path(__file__).parent / "data" / "evaluation" / "memory_module.db" + # Create db directory if it doesn't exist + db_path.parent.mkdir(parents=True, exist_ok=True) + config = MemoryModuleConfig( + db_path=db_path, + buffer_size=5, + timeout_seconds=60, + llm=llm_config, + ) + + # Delete existing db if it exists + if db_path.exists(): + db_path.unlink() + + memory_module = MemoryModule(config=config) + + results = [] + successes = 0 + failures = 0 + + # Run evaluations with progress bar + print("\nEvaluating memory processing decisions...") + for test_case in tqdm(TEST_CASES, desc="Processing test cases"): + result = await evaluate_decision(memory_module, test_case) + results.append(result) + if result["success"]: + successes += 1 + else: + failures += 1 + + # Calculate statistics + total = len(TEST_CASES) + success_rate = (successes / total) * 100 + + # Print summary + print("\n=== Evaluation Summary ===") + print(f"Total test cases: {total}") + print(f"Successes: {successes} ({success_rate:.1f}%)") + print(f"Failures: {failures} ({100 - success_rate:.1f}%)") + + # Print detailed failures if any + if failures > 0: + print("\n=== Failed Cases ===") + for result in results: + if not result["success"]: + test_case = result["test_case"] + print(f"\nTest Case: {test_case['title']}") + print(f"Reason: {test_case['reason']}") + print(f"Actual result: {result['reason']}") + print(f"Expected: {result['expected']}") + print(f"Got: {result['got']}") + print("Old messages:") + for msg in test_case["old_messages"]: + print(f" - {msg}") + print(f"New message: {test_case['incoming_message']}") + print("-" * 50) + + # Cleanup + message_queue = cast(MessageQueue, memory_module.message_queue) + message_buffer = cast(MessageBuffer, message_queue.message_buffer) + scheduler = cast(Scheduler, message_buffer.scheduler) + await scheduler.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index 42a30d1c..d330d6fd 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -116,8 +116,10 @@ async def _mock_embedding(**kwargs): @pytest_asyncio.fixture(autouse=True) async def cleanup_scheduled_events(memory_module): """Fixture to cleanup scheduled events after each test.""" - yield - await memory_module.message_queue.message_buffer.scheduler.cleanup() + try: + yield + finally: + await memory_module.message_queue.message_buffer.scheduler.cleanup() @pytest.mark.asyncio @@ -342,7 +344,7 @@ async def _validate_decision(memory_module, message: List[UserMessageInput], exp assert extraction.action == "add" and extraction.facts for fact in extraction.facts: decision = await memory_module.memory_core._get_add_memory_processing_decision(fact.text, "user-123") - assert decision == expected_decision + assert decision.decision == expected_decision @pytest.mark.asyncio From 1acb9d208c587aa1db33b96fa4ce225c787c4be3 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Thu, 9 Jan 2025 23:25:56 -0800 Subject: [PATCH 03/11] Add tests --- scripts/evaluate_memory_decisions.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/scripts/evaluate_memory_decisions.py b/scripts/evaluate_memory_decisions.py index 61589f3a..e34afe64 100644 --- a/scripts/evaluate_memory_decisions.py +++ b/scripts/evaluate_memory_decisions.py @@ -6,8 +6,6 @@ from uuid import uuid4 from memory_module import MemoryModule, MemoryModuleConfig, UserMessageInput -from memory_module.core.message_buffer import MessageBuffer -from memory_module.core.message_queue import MessageQueue from memory_module.core.scheduler import Scheduler from tqdm import tqdm @@ -204,10 +202,7 @@ async def main(): print("-" * 50) # Cleanup - message_queue = cast(MessageQueue, memory_module.message_queue) - message_buffer = cast(MessageBuffer, message_queue.message_buffer) - scheduler = cast(Scheduler, message_buffer.scheduler) - await scheduler.cleanup() + await cast(Scheduler, memory_module.message_queue.message_buffer.scheduler).cleanup() if __name__ == "__main__": From 97e44fee303566a82a3611d1a87da9bbad06e1d5 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Thu, 9 Jan 2025 23:30:48 -0800 Subject: [PATCH 04/11] Fix --- scripts/evaluate_memory_decisions.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/evaluate_memory_decisions.py b/scripts/evaluate_memory_decisions.py index e34afe64..43d45c6f 100644 --- a/scripts/evaluate_memory_decisions.py +++ b/scripts/evaluate_memory_decisions.py @@ -5,11 +5,15 @@ from typing import cast from uuid import uuid4 -from memory_module import MemoryModule, MemoryModuleConfig, UserMessageInput -from memory_module.core.scheduler import Scheduler from tqdm import tqdm -from tests.memory_module.utils import build_llm_config +root_dir = Path(__file__).parent.parent +sys.path.extend([str(root_dir), str(root_dir / "packages")]) + +from memory_module import MemoryModule, MemoryModuleConfig, UserMessageInput # noqa: E402 +from memory_module.services.scheduled_events_service import ScheduledEventsService # noqa: E402 + +from tests.memory_module.utils import build_llm_config # noqa: E402 # Test cases from before TEST_CASES = [ @@ -202,7 +206,7 @@ async def main(): print("-" * 50) # Cleanup - await cast(Scheduler, memory_module.message_queue.message_buffer.scheduler).cleanup() + await cast(ScheduledEventsService, memory_module.message_queue.message_buffer.scheduler).cleanup() if __name__ == "__main__": From 83d0e9ade5f95eeeb0da6743ff479e4c2b6004fb Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Fri, 10 Jan 2025 09:18:11 -0800 Subject: [PATCH 05/11] Topics --- .pre-commit-config.yaml | 2 + packages/evals/benchmark_memory_module.py | 8 +- packages/memory_module/__init__.py | 2 + packages/memory_module/config.py | 17 + packages/memory_module/core/memory_core.py | 59 +++- packages/memory_module/core/memory_module.py | 18 +- .../interfaces/base_memory_core.py | 14 +- .../interfaces/base_memory_module.py | 14 +- .../interfaces/base_memory_storage.py | 10 +- packages/memory_module/interfaces/types.py | 22 +- .../storage/in_memory_storage.py | 67 +++- .../migrations/14_add_topic_to_memories.sql | 1 + .../storage/sqlite_memory_storage.py | 308 ++++++++---------- src/tech_assistant_agent/tools.py | 4 +- tests/memory_module/test_in_memory_storage.py | 0 tests/memory_module/test_memory_module.py | 235 +++++++++++-- tests/memory_module/test_memory_storage.py | 175 +++++++++- 17 files changed, 705 insertions(+), 251 deletions(-) create mode 100644 packages/memory_module/storage/migrations/14_add_topic_to_memories.sql delete mode 100644 tests/memory_module/test_in_memory_storage.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a3744108..08f3f5d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,3 +13,5 @@ repos: rev: v1.13.0 hooks: - id: mypy + pass_filenames: false + args: [--ignore-missing-imports, --show-traceback] diff --git a/packages/evals/benchmark_memory_module.py b/packages/evals/benchmark_memory_module.py index fb91fd11..a93a2740 100644 --- a/packages/evals/benchmark_memory_module.py +++ b/packages/evals/benchmark_memory_module.py @@ -16,9 +16,9 @@ from memory_module.config import LLMConfig, MemoryModuleConfig from memory_module.core.memory_module import MemoryModule -from memory_module.interfaces.types import AssistantMessage, UserMessage +from memory_module.interfaces.types import AssistantMessage, RetrievalConfig, UserMessage -from evals.helpers import Dataset, DatasetItem, load_dataset, setup_mlflow +from evals.helpers import Dataset, DatasetItem, SessionMessage, load_dataset, setup_mlflow from evals.metrics import string_check_metric setup_mlflow(experiment_name="memory_module") @@ -48,7 +48,7 @@ def __exit__(self, exc_type, exc_value, traceback): os.remove(self._db_path) -async def add_messages(memory_module: MemoryModule, messages: List[dict]): +async def add_messages(memory_module: MemoryModule, messages: List[SessionMessage]): def create_message(**kwargs): params = { "id": str(uuid.uuid4()), @@ -94,7 +94,7 @@ async def benchmark_memory_module(input: DatasetItem): # buffer size has to be the same as the session length to trigger sm processing with MemoryModuleManager(buffer_size=len(session)) as memory_module: await add_messages(memory_module, messages=session) - memories = await memory_module.retrieve_memories(query, user_id=None, limit=None) + memories = await memory_module.retrieve_memories(None, RetrievalConfig(query=query, limit=None)) return { "input": { diff --git a/packages/memory_module/__init__.py b/packages/memory_module/__init__.py index cc9d2ae3..5fd170e1 100644 --- a/packages/memory_module/__init__.py +++ b/packages/memory_module/__init__.py @@ -9,6 +9,7 @@ Memory, Message, MessageInput, + RetrievalConfig, ShortTermMemoryRetrievalConfig, UserMessage, UserMessageInput, @@ -29,6 +30,7 @@ "MessageInput", "AssistantMessage", "AssistantMessageInput", + "RetrievalConfig", "ShortTermMemoryRetrievalConfig", "MemoryMiddleware", ] diff --git a/packages/memory_module/config.py b/packages/memory_module/config.py index e3494ed9..0cd2d83c 100644 --- a/packages/memory_module/config.py +++ b/packages/memory_module/config.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, ConfigDict, Field +from memory_module.interfaces.types import Topic + class LLMConfig(BaseModel): """Configuration for LLM service.""" @@ -16,6 +18,18 @@ class LLMConfig(BaseModel): embedding_model: Optional[str] = None +DEFAULT_TOPICS = [ + Topic( + name="General Interests and Preferences", + description="When a user mentions specific events or actions, focus on the underlying interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference, not the date or location).", # noqa: E501 + ), + Topic( + name="General Facts about the user", + description="Facts that describe relevant information about the user, such as details about where they live or things they own.", # noqa: E501 + ), +] + + class MemoryModuleConfig(BaseModel): """Configuration for memory module components. @@ -35,4 +49,7 @@ class MemoryModuleConfig(BaseModel): description="Seconds to wait before processing a conversation", ) llm: LLMConfig = Field(description="LLM service configuration") + topics: list[Topic] = Field( + default=DEFAULT_TOPICS, description="List of topics that the memory module should listen to", min_length=1 + ) enable_logging: bool = Field(default=False, description="Enable verbose logging for memory module") diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index 9c12912a..9d9ad301 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -10,12 +10,14 @@ from memory_module.interfaces.base_memory_storage import BaseMemoryStorage from memory_module.interfaces.types import ( BaseMemoryInput, - EmbedText, Memory, MemoryType, Message, MessageInput, + RetrievalConfig, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, ) from memory_module.services.llm_service import LLMService from memory_module.storage.in_memory_storage import InMemoryStorage @@ -51,6 +53,11 @@ class SemanticFact(BaseModel): default_factory=set, description="The indices of the messages that the fact was extracted from.", ) + # TODO: Add a validator to ensure that topics are valid + topics: Optional[List[str]] = Field( + default=None, + description="The name of the topic that the fact is most relevant to.", # noqa: E501 + ) class SemanticMemoryExtraction(BaseModel): @@ -106,6 +113,7 @@ def __init__( self.storage: BaseMemoryStorage = storage or ( SQLiteMemoryStorage(db_path=config.db_path) if config.db_path is not None else InMemoryStorage() ) + self.topics = config.topics async def process_semantic_messages( self, @@ -145,6 +153,7 @@ async def process_semantic_messages( user_id=author_id, message_attributions=list(message_ids), memory_type=MemoryType.SEMANTIC, + topics=fact.topics, ) embed_vectors = await self._get_semantic_fact_embeddings(fact.text, metadata) await self.storage.store_memory(memory, embedding_vectors=embed_vectors) @@ -154,7 +163,22 @@ async def process_episodic_messages(self, messages: List[Message]) -> None: # TODO: Implement episodic memory processing await self._extract_episodic_memory_from_messages(messages) - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: + return await self._retrieve_memories( + user_id, config.query, config.topics if config.topics else None, config.limit + ) + + async def _retrieve_memories( + self, + user_id: Optional[str], + query: Optional[str], + topics: Optional[List[Topic]], + limit: Optional[int], + ) -> List[Memory]: """Retrieve memories based on a query. Steps: @@ -162,8 +186,14 @@ async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Opt 2. Find relevant memories 3. Possibly rerank or filter results """ - embedText = EmbedText(text=query, embedding_vector=await self._get_query_embedding(query)) - return await self.storage.retrieve_memories(embedText, user_id, limit) + if query: + text_embedding = TextEmbedding(text=query, embedding_vector=await self._get_query_embedding(query)) + else: + text_embedding = None + + return await self.storage.retrieve_memories( + user_id=user_id, text_embedding=text_embedding, topics=topics, limit=limit + ) async def update_memory(self, memory_id: str, updated_memory: str) -> None: metadata = await self._extract_metadata_from_fact(updated_memory) @@ -195,10 +225,15 @@ async def remove_messages(self, message_ids: List[str]) -> None: logger.info("messages {} are removed".format(",".join(message_ids))) async def _get_add_memory_processing_decision( - self, new_memory_fact: str, user_id: Optional[str] + self, new_memory_fact: SemanticFact, user_id: Optional[str] ) -> ProcessSemanticMemoryDecision: - similar_memories = await self.retrieve_memories(new_memory_fact, user_id, None) - decision = await self._extract_memory_processing_decision(new_memory_fact, similar_memories, user_id) + topics = ( + [topic for topic in self.topics if topic.name in new_memory_fact.topics] if new_memory_fact.topics else None + ) + similar_memories = await self.retrieve_memories( + user_id, RetrievalConfig(query=new_memory_fact.text, topics=topics, limit=None) + ) + decision = await self._extract_memory_processing_decision(new_memory_fact.text, similar_memories, user_id) return decision async def _extract_memory_processing_decision( @@ -306,6 +341,9 @@ async def _extract_semantic_fact_from_messages( else: # we explicitly ignore internal messages continue + topics = "\n".join( + [f"{topic.description}" for topic in self.topics] + ) existing_memories_str = "" if existing_memories: @@ -318,11 +356,7 @@ async def _extract_semantic_fact_from_messages( that will remain relevant over time, even if the user is mentioning short-term plans or events. Prioritize: -- General Interests and Preferences: When a user mentions specific events or actions, focus on the underlying -interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference, -not the date or location). -- Facts or Details about user: Extract facts that describe relevant information about the user, such as details about things they own. -- Facts about the user that the assistant might find useful. +{topics} Avoid: - Extraction memories that already exist in the system. If a fact is already stored, ignore it. @@ -335,7 +369,6 @@ async def _extract_semantic_fact_from_messages( {messages_str} """ # noqa: E501 - llm_messages = [ {"role": "system", "content": system_message}, { diff --git a/packages/memory_module/core/memory_module.py b/packages/memory_module/core/memory_module.py index 8163cd8f..e4e86aa8 100644 --- a/packages/memory_module/core/memory_module.py +++ b/packages/memory_module/core/memory_module.py @@ -7,7 +7,13 @@ from memory_module.interfaces.base_memory_core import BaseMemoryCore from memory_module.interfaces.base_memory_module import BaseMemoryModule from memory_module.interfaces.base_message_queue import BaseMessageQueue -from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + Memory, + Message, + MessageInput, + RetrievalConfig, + ShortTermMemoryRetrievalConfig, +) from memory_module.services.llm_service import LLMService from memory_module.utils.logging import configure_logging @@ -50,10 +56,14 @@ async def add_message(self, message: MessageInput) -> Message: await self.message_queue.enqueue(message_res) return message_res - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: """Retrieve relevant memories based on a query.""" - logger.debug(f"retrieve memories from (query: {query}, user_id: {user_id}, limit: {limit})") - memories = await self.memory_core.retrieve_memories(query, user_id, limit) + logger.debug(f"retrieve memories from (query: {config.query}, user_id: {user_id}, limit: {config.limit})") + memories = await self.memory_core.retrieve_memories(user_id=user_id, config=config) logger.debug(f"retrieved memories: {memories}") return memories diff --git a/packages/memory_module/interfaces/base_memory_core.py b/packages/memory_module/interfaces/base_memory_core.py index 8efc02da..201c2713 100644 --- a/packages/memory_module/interfaces/base_memory_core.py +++ b/packages/memory_module/interfaces/base_memory_core.py @@ -1,7 +1,13 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + Memory, + Message, + MessageInput, + RetrievalConfig, + ShortTermMemoryRetrievalConfig, +) class BaseMemoryCore(ABC): @@ -22,7 +28,11 @@ async def process_episodic_messages(self, messages: List[Message]) -> None: pass @abstractmethod - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: """Retrieve memories based on a query.""" pass diff --git a/packages/memory_module/interfaces/base_memory_module.py b/packages/memory_module/interfaces/base_memory_module.py index fb8890e4..69902d69 100644 --- a/packages/memory_module/interfaces/base_memory_module.py +++ b/packages/memory_module/interfaces/base_memory_module.py @@ -1,7 +1,13 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + Memory, + Message, + MessageInput, + RetrievalConfig, + ShortTermMemoryRetrievalConfig, +) class BaseMemoryModule(ABC): @@ -13,7 +19,11 @@ async def add_message(self, message: MessageInput) -> Message: pass @abstractmethod - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: """Retrieve relevant memories based on a query.""" pass diff --git a/packages/memory_module/interfaces/base_memory_storage.py b/packages/memory_module/interfaces/base_memory_storage.py index 352baa21..75036c78 100644 --- a/packages/memory_module/interfaces/base_memory_storage.py +++ b/packages/memory_module/interfaces/base_memory_storage.py @@ -3,11 +3,12 @@ from memory_module.interfaces.types import ( BaseMemoryInput, - EmbedText, Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, ) @@ -47,7 +48,12 @@ async def store_short_term_memory(self, message: MessageInput) -> Message: @abstractmethod async def retrieve_memories( - self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None + self, + *, + user_id: Optional[str], + text_embedding: Optional[TextEmbedding] = None, + topics: Optional[List[Topic]] = None, + limit: Optional[int] = None, ) -> List[Memory]: """Retrieve memories based on a query. diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index d7bd4ad3..c3821f68 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -104,6 +104,12 @@ class BaseMemoryInput(BaseModel): memory_type: MemoryType user_id: Optional[str] = None message_attributions: Optional[List[str]] = Field(default=[]) + topics: Optional[List[str]] = None + + +class Topic(BaseModel): + name: str = Field(description="A unique name of the topic that the memory module should listen to") + description: str = Field(description="Description of the topic") class Memory(BaseMemoryInput): @@ -112,12 +118,24 @@ class Memory(BaseMemoryInput): id: str -class EmbedText(BaseModel): +class TextEmbedding(BaseModel): text: str embedding_vector: List[float] -class ShortTermMemoryRetrievalConfig(BaseModel): +class RetrievalConfig(BaseModel): + query: Optional[str] = None + topic: Optional[Topic] = None + limit: Optional[int] = None + + @model_validator(mode="after") + def check_parameters(self) -> "RetrievalConfig": + if self.query is None and self.topic is None: + raise ValueError("Either query or topic must be provided") + return self + + +class ShortTermMemoryRetrievalConfig(RetrievalConfig): n_messages: Optional[int] = None # Number of messages to retrieve last_minutes: Optional[float] = None # Time frame in minutes before: Optional[datetime] = None # Retrieve messages up until a specific timestamp diff --git a/packages/memory_module/storage/in_memory_storage.py b/packages/memory_module/storage/in_memory_storage.py index 12c3cae6..6e8024d2 100644 --- a/packages/memory_module/storage/in_memory_storage.py +++ b/packages/memory_module/storage/in_memory_storage.py @@ -14,13 +14,14 @@ AssistantMessage, AssistantMessageInput, BaseMemoryInput, - EmbedText, InternalMessage, InternalMessageInput, Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, UserMessage, UserMessageInput, ) @@ -112,26 +113,54 @@ async def store_short_term_memory(self, message: MessageInput) -> Message: return message_obj async def retrieve_memories( - self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None + self, + *, + user_id: Optional[str], + text_embedding: Optional[TextEmbedding] = None, + topics: Optional[List[Topic]] = None, + limit: Optional[int] = None, ) -> List[Memory]: limit = limit or self.default_limit - sorted_memories: list[_MemorySimilarity] = [] + memories = [] + + # Filter memories by user_id and topics first + filtered_memories = list(self.storage["memories"].values()) + if user_id: + filtered_memories = [m for m in filtered_memories if m.user_id == user_id] + if topics: + filtered_memories = [ + m for m in filtered_memories if m.topics and any(topic.name in m.topics for topic in topics) + ] - for memory_id, embeddings in self.storage["embeddings"].items(): - memory = self.storage["memories"][memory_id] - if user_id and memory.user_id != user_id: - continue + # If we have text_embedding, calculate similarities and sort + if text_embedding: + sorted_memories: list[_MemorySimilarity] = [] - # Find the embedding with highest similarity (lowest distance) - best_similarity = float("-inf") - for embedding in embeddings: - similarity = self._cosine_similarity(embedText.embedding_vector, embedding) - best_similarity = max(best_similarity, similarity) + for memory in filtered_memories: + embeddings = self.storage["embeddings"].get(memory.id, []) + if not embeddings: + continue - sorted_memories.append(_MemorySimilarity(memory, best_similarity)) + # Find the embedding with lowest distance + best_distance = float("inf") + for embedding in embeddings: + distance = self.l2_distance(text_embedding.embedding_vector, embedding) + best_distance = min(best_distance, distance) - sorted_memories.sort(key=lambda x: x.similarity, reverse=True) - return [Memory(**item.memory.__dict__) for item in sorted_memories[:limit]] + # Filter based on distance threshold + if best_distance > 1.0: # adjust threshold as needed + continue + + sorted_memories.append(_MemorySimilarity(memory, best_distance)) + + # Sort by distance (ascending instead of descending) + sorted_memories.sort(key=lambda x: x.similarity) + memories = [Memory(**item.memory.__dict__) for item in sorted_memories[:limit]] + else: + # If no embedding, sort by created_at + memories = sorted(filtered_memories, key=lambda x: x.created_at, reverse=True)[:limit] + + return memories async def get_memories(self, memory_ids: List[str]) -> List[Memory]: return [ @@ -169,8 +198,12 @@ async def remove_memories(self, memory_ids: List[str]) -> None: self.storage["embeddings"].pop(memory_id, None) self.storage["memories"].pop(memory_id, None) - def _cosine_similarity(self, memory_vector: List[float], query_vector: List[float]) -> float: - return np.dot(np.array(query_vector), np.array(memory_vector)) + def l2_distance(self, memory_vector: List[float], query_vector: List[float]) -> float: + memory_array = np.array(memory_vector) + query_array = np.array(query_vector) + + # Compute L2 (Euclidean) distance: sqrt(sum((a-b)^2)) + return np.sqrt(np.sum((memory_array - query_array) ** 2)) async def clear_memories(self, user_id: str) -> None: memory_ids_for_user = [ diff --git a/packages/memory_module/storage/migrations/14_add_topic_to_memories.sql b/packages/memory_module/storage/migrations/14_add_topic_to_memories.sql new file mode 100644 index 00000000..511d49f6 --- /dev/null +++ b/packages/memory_module/storage/migrations/14_add_topic_to_memories.sql @@ -0,0 +1 @@ +ALTER TABLE memories ADD COLUMN topics TEXT; \ No newline at end of file diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index eef8a445..d7e92b07 100644 --- a/packages/memory_module/storage/sqlite_memory_storage.py +++ b/packages/memory_module/storage/sqlite_memory_storage.py @@ -8,12 +8,13 @@ from memory_module.interfaces.base_memory_storage import BaseMemoryStorage from memory_module.interfaces.types import ( BaseMemoryInput, - EmbedText, InternalMessageInput, Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, ) from memory_module.storage.sqlite_storage import SQLiteStorage from memory_module.storage.utils import build_message_from_dict @@ -40,16 +41,20 @@ async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List async with self.storage.transaction() as cursor: # Store the memory + # Convert topics list to comma-separated string if it's a list + topics_str = ",".join(memory.topics) if isinstance(memory.topics, list) else memory.topics + await cursor.execute( """INSERT INTO memories - (id, content, created_at, user_id, memory_type) - VALUES (?, ?, ?, ?, ?)""", + (id, content, created_at, user_id, memory_type, topics) + VALUES (?, ?, ?, ?, ?, ?)""", ( memory_id, memory.content, memory.created_at.astimezone(datetime.timezone.utc), memory.user_id, memory.memory_type.value, + topics_str, ), ) @@ -107,112 +112,112 @@ async def update_memory(self, memory_id: str, updated_memory: str, *, embedding_ ) async def retrieve_memories( - self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None + self, + *, + user_id: Optional[str], + text_embedding: Optional[TextEmbedding] = None, + topics: Optional[List[Topic]] = None, + limit: Optional[int] = None, ) -> List[Memory]: - """Retrieve memories based on a query.""" - query = """ - WITH ranked_memories AS ( - SELECT - e.memory_id, - distance - FROM vec_items - JOIN embeddings e ON vec_items.memory_embedding_id = e.id - WHERE vec_items.embedding MATCH ? AND K = ? AND distance < ? - ORDER BY distance ASC - ) + base_query = """ SELECT - m.id, - m.content, - m.created_at, - m.user_id, - m.memory_type, - ma.message_id, - rm.distance - FROM ranked_memories rm - JOIN memories m ON m.id = rm.memory_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions + {distance_select} + FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id - ORDER BY rm.distance ASC + {embedding_join} + WHERE 1=1 + {topic_filter} + {user_filter} + GROUP BY m.id + {order_by} + {limit_clause} """ - rows = await self.storage.fetch_all( - query, - ( - sqlite_vec.serialize_float32(embedText.embedding_vector), - limit or self.default_limit, - 1.0, - ), + params = [] + + # Handle embedding search first since its params come first in the query + embedding_join = "" + distance_select = "" + order_by = "ORDER BY m.created_at DESC" + + if text_embedding: + embedding_join = """ + JOIN ( + SELECT + e.memory_id, + distance + FROM vec_items + JOIN embeddings e ON vec_items.memory_embedding_id = e.id + WHERE vec_items.embedding MATCH ? AND K = ? AND distance < ? + ) rm ON m.id = rm.memory_id + """ + distance_select = ", rm.distance as _distance" + order_by = "ORDER BY rm.distance ASC" + params.extend( + [sqlite_vec.serialize_float32(text_embedding.embedding_vector), limit or self.default_limit, 1.0] + ) + + # Handle topic and user filters after embedding params + topic_filter = "" + if topics: + # Create a single AND condition with multiple LIKE clauses + topic_filter = " AND (" + " OR ".join(["m.topics LIKE ?"] * len(topics)) + ")" + params.extend(f"%{t.name}%" for t in topics) + + user_filter = "" + if user_id: + user_filter = "AND m.user_id = ?" + params.append(user_id) + + # Handle limit last + limit_clause = "" + if limit and not text_embedding: # Only add LIMIT if not using vector search + limit_clause = "LIMIT ?" + params.append(limit or self.default_limit) + + query = base_query.format( + distance_select=distance_select, + embedding_join=embedding_join, + topic_filter=topic_filter, + user_filter=user_filter, + order_by=order_by, + limit_clause=limit_clause, ) - # Group rows by memory_id to handle message attributions - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - "distance": row["distance"], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + rows = await self.storage.fetch_all(query, tuple(params)) + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def clear_memories(self, user_id: str) -> None: """Clear all memories for a given user.""" query = """ SELECT - m.id, - e.id AS embed_id + m.id FROM memories m - LEFT JOIN embeddings e - WHERE m.user_id = ? AND m.id = e.memory_id + WHERE m.user_id = ? """ - id_rows = await self.storage.fetch_all(query, (user_id,)) - memory_id_list = [row["id"] for row in id_rows] - embed_id_list = [row["embed_id"] for row in id_rows] - + rows = await self.storage.fetch_all(query, (user_id,)) + memory_ids = [row["id"] for row in rows] # Remove memory - await self._remove_memories_and_embeddings(memory_id_list, ",".join(["?"] * len(embed_id_list)), embed_id_list) + await self.remove_memories(memory_ids) async def get_memory(self, memory_id: int) -> Optional[Memory]: - """Retrieve a memory with its message attributions.""" query = """ SELECT - m.id, - m.content, - m.created_at, - m.updated_at, - m.user_id, - m.memory_type, - ma.message_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id WHERE m.id = ? + GROUP BY m.id """ - rows = await self.storage.fetch_all(query, (memory_id,)) - - if not rows: + row = await self.storage.fetch_one(query, (memory_id,)) + if not row: return None - # First row contains the memory data - memory_data = { - "id": rows[0]["id"], - "content": rows[0]["content"], - "created_at": rows[0]["created_at"], - "updated_at": rows[0]["updated_at"], - "user_id": rows[0]["user_id"], - "memory_type": rows[0]["memory_type"], - "message_attributions": [row["message_id"] for row in rows if row["message_id"]], - } - - return Memory(**memory_data) + return self._build_memory(row, (row["message_attributions"] or "").split(",")) async def get_all_memories( self, limit: Optional[int] = None, message_ids: Optional[List[str]] = None @@ -220,12 +225,8 @@ async def get_all_memories( """Retrieve all memories with their message attributions.""" query = """ SELECT - m.id, - m.content, - m.created_at, - m.user_id, - m.memory_type, - ma.message_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id """ @@ -236,32 +237,17 @@ async def get_all_memories( message_ids, ) - query += " ORDER BY m.created_at DESC" + query += """ + GROUP BY m.id + ORDER BY m.created_at DESC + """ if limit is not None: query += " LIMIT ?" params += (limit,) - rows = await self.storage.fetch_all(query, tuple(params)) - - # Group rows by memory_id - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + rows = await self.storage.fetch_all(query, params) + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def store_short_term_memory(self, message: MessageInput) -> Message: """Store a short-term memory entry.""" @@ -332,93 +318,66 @@ async def retrieve_chat_history( return [build_message_from_dict(row) for row in rows][::-1] async def get_memories(self, memory_ids: List[str]) -> List[Memory]: - query = """ + query = f""" SELECT - m.id, - m.content, - m.created_at, - m.user_id, - m.memory_type, - ma.message_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id - WHERE m.id IN ({}) - """.format(",".join(["?"] * len(memory_ids))) + WHERE m.id IN ({','.join(['?'] * len(memory_ids))}) + """ rows = await self.storage.fetch_all(query, tuple(memory_ids)) - # Group rows by memory_id - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def get_user_memories(self, user_id: str) -> List[Memory]: - """Get memories based on user id.""" query = """ SELECT m.*, - ma.message_id + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id WHERE m.user_id = ? + GROUP BY m.id """ rows = await self.storage.fetch_all(query, (user_id,)) - - # Group rows by memory_id - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: - """Get messages based on memory ids.""" - query = """ - SELECT ma.memory_id, m.* + query = f""" + SELECT + ma.memory_id as _memory_id, + m.* FROM memory_attributions ma JOIN messages m ON ma.message_id = m.id - WHERE ma.memory_id IN ({}) - """.format(",".join(["?"] * len(memory_ids))) + WHERE ma.memory_id IN ({','.join(['?'] * len(memory_ids))}) + """ rows = await self.storage.fetch_all(query, tuple(memory_ids)) messages_dict: Dict[str, List[Message]] = {} for row in rows: - memory_id = row["memory_id"] + memory_id = row["_memory_id"] if memory_id not in messages_dict: messages_dict[memory_id] = [] - - messages_dict[memory_id].append(build_message_from_dict(row)) + messages_dict[memory_id].append( + build_message_from_dict({k: v for k, v in row.items() if not k.startswith("_")}) + ) return messages_dict + def _build_memory(self, memory_values: dict, message_attributions: List[str]) -> Memory: + memory_keys = ["id", "content", "created_at", "user_id", "memory_type", "topics"] + # Convert topics string back to list if it exists + if memory_values.get("topics"): + memory_values["topics"] = memory_values["topics"].split(",") + return Memory( + **{k: v for k, v in memory_values.items() if k in memory_keys}, + message_attributions=message_attributions, + ) + async def remove_messages(self, message_ids: List[str]) -> None: async with self.storage.transaction() as cursor: await cursor.execute( @@ -427,21 +386,12 @@ async def remove_messages(self, message_ids: List[str]) -> None: ) async def remove_memories(self, memory_ids: List[str]) -> None: - query = """ - SELECT - id - FROM embeddings - WHERE memory_id in ({}) - """.format(",".join(["?"] * len(memory_ids))) - await self._remove_memories_and_embeddings(memory_ids, query) - - async def _remove_memories_and_embeddings( - self, memory_ids: List[str], embed_query: str, embed_ids: Optional[List[str]] = None - ) -> None: async with self.storage.transaction() as cursor: await cursor.execute( - "DELETE FROM vec_items WHERE memory_embedding_id in ({})".format(embed_query), - tuple(embed_ids or memory_ids), + """DELETE FROM vec_items WHERE memory_embedding_id in ( + SELECT id FROM embeddings WHERE memory_id in ({}) + )""".format(",".join(["?"] * len(memory_ids))), + tuple(memory_ids), ) await cursor.execute( diff --git a/src/tech_assistant_agent/tools.py b/src/tech_assistant_agent/tools.py index 4e3193a4..072f175f 100644 --- a/src/tech_assistant_agent/tools.py +++ b/src/tech_assistant_agent/tools.py @@ -3,7 +3,7 @@ from botbuilder.core import TurnContext from botbuilder.schema import Activity -from memory_module import BaseMemoryModule, Memory +from memory_module import BaseMemoryModule, Memory, RetrievalConfig from pydantic import BaseModel, Field from teams.ai.citations import AIEntity, Appearance, ClientCitation @@ -49,7 +49,7 @@ async def get_candidate_tasks(candidate_tasks: GetCandidateTasks) -> str: async def get_memorized_fields(memory_module: BaseMemoryModule, fields_to_retrieve: GetMemorizedFields) -> str: empty_obj: dict = {} for query in fields_to_retrieve.queries_for_fields: - result = await memory_module.retrieve_memories(query, None, None) + result = await memory_module.retrieve_memories(None, RetrievalConfig(query=query, limit=None)) print("Getting memorized queries: ", query) print(result) print("---") diff --git a/tests/memory_module/test_in_memory_storage.py b/tests/memory_module/test_in_memory_storage.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index d330d6fd..d37394b2 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -11,7 +11,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from memory_module.config import MemoryModuleConfig +from memory_module.config import DEFAULT_TOPICS, MemoryModuleConfig, Topic # noqa: I001 from memory_module.core.memory_core import ( EpisodicMemoryExtraction, MemoryCore, @@ -21,6 +21,8 @@ ) from memory_module.core.memory_module import MemoryModule from memory_module.interfaces.types import ( + AssistantMessageInput, + RetrievalConfig, ShortTermMemoryRetrievalConfig, UserMessageInput, ) @@ -31,16 +33,21 @@ @pytest.fixture -def config(): +def config(request): """Fixture to create test config.""" + params = request.param if hasattr(request, "param") else {} llm_config = build_llm_config() + buffer_size = params.get("buffer_size", 5) + timeout_seconds = params.get("timeout_seconds", 60) + topics = params.get("topics", DEFAULT_TOPICS) if not llm_config.api_key: pytest.skip("OpenAI API key not provided") return MemoryModuleConfig( db_path=Path(__file__).parent / "data" / "tests" / "memory_module.db", - buffer_size=5, - timeout_seconds=60, + buffer_size=buffer_size, + timeout_seconds=timeout_seconds, llm=llm_config, + topics=topics, ) @@ -148,12 +155,12 @@ async def test_simple_conversation(memory_module): await memory_module.message_queue.message_buffer.scheduler.flush() stored_memories = await memory_module.memory_core.storage.get_all_memories() - assert len(stored_memories) == 2 + assert len(stored_memories) >= 1 assert any("pie" in message.content for message in stored_memories) assert any(message.id in stored_memories[0].message_attributions for message in messages) assert all(memory.memory_type == "semantic" for memory in stored_memories) - result = await memory_module.retrieve_memories("apple pie", "", 1) + result = await memory_module.retrieve_memories("user-123", RetrievalConfig(query="apple pie", limit=1)) assert len(result) == 1 assert result[0].id == next(memory.id for memory in stored_memories if "apple pie" in memory.content) @@ -285,6 +292,17 @@ async def test_short_term_memory(memory_module): @pytest.mark.asyncio async def test_add_memory_processing_decision(memory_module): """Test whether to process adding memory""" + + async def _validate_decision(memory_module, message: List[UserMessageInput], expected_decision: str): + extraction = await memory_module.memory_core._extract_semantic_fact_from_messages(message) + assert extraction.action == "add" and extraction.facts + for fact in extraction.facts: + decision = await memory_module.memory_core._get_add_memory_processing_decision(fact, "user-123") + if decision != expected_decision: + # Adding this because this test is flaky and it would be good to know why. + print(f"Decision: {decision}, Expected: {expected_decision}", fact, decision) + assert decision == expected_decision + conversation_id = str(uuid4()) old_messages = [ UserMessageInput( @@ -292,21 +310,21 @@ async def test_add_memory_processing_decision(memory_module): content="I have a Pokemon limited version Mac book.", author_id="user-123", conversation_ref=conversation_id, - created_at=datetime.strptime("2024-09-01", "%Y-%m-%d"), + created_at=datetime.now() - timedelta(minutes=3), ), UserMessageInput( id=str(uuid4()), content="I bought a pink iphone.", author_id="user-123", conversation_ref=conversation_id, - created_at=datetime.strptime("2024-09-03", "%Y-%m-%d"), + created_at=datetime.now() - timedelta(minutes=2), ), UserMessageInput( id=str(uuid4()), - content="I just had another Mac book.", + content="I just bought a Mac book.", author_id="user-123", conversation_ref=conversation_id, - created_at=datetime.strptime("2024-10-12", "%Y-%m-%d"), + created_at=datetime.now() - timedelta(minutes=1), ), ] new_messages = [ @@ -322,7 +340,7 @@ async def test_add_memory_processing_decision(memory_module): [ UserMessageInput( id=str(uuid4()), - content="I bought one more new Mac book", + content="I got a new cat!", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), @@ -339,14 +357,6 @@ async def test_add_memory_processing_decision(memory_module): await _validate_decision(memory_module, new_messages[1], "add") -async def _validate_decision(memory_module, message: List[UserMessageInput], expected_decision: str): - extraction = await memory_module.memory_core._extract_semantic_fact_from_messages(message) - assert extraction.action == "add" and extraction.facts - for fact in extraction.facts: - decision = await memory_module.memory_core._get_add_memory_processing_decision(fact.text, "user-123") - assert decision.decision == expected_decision - - @pytest.mark.asyncio async def test_remove_messages(memory_module): conversation1_id = str(uuid4()) @@ -418,3 +428,190 @@ async def test_remove_messages(memory_module): conversation_refs = list(updated_buffer.keys()) assert len(conversation_refs) == 1 assert conversation_refs[0] == conversation3_id + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config", + [ + { + "topics": [ + Topic(name="Device Type", description="The type of device the user has"), + Topic(name="Operating System", description="The user's operating system"), + Topic(name="Device year", description="The year of the user's device"), + ], + "buffer_size": 10, + } + ], + indirect=True, +) +async def test_topic_extraction(memory_module): + conversation_id = str(uuid4()) + messages = [ + {"role": "user", "content": "I need help with my device..."}, + {"role": "assistant", "content": "I'm sorry to hear that. What device do you have?"}, + {"role": "user", "content": "I have a Macbook"}, + {"role": "assistant", "content": "What is the year of your device?"}, + {"role": "user", "content": "2024"}, + ] + + for message in messages: + if message["role"] == "user": + input = UserMessageInput( + id=str(uuid4()), + content=message["content"], + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now(), + ) + else: + input = AssistantMessageInput( + id=str(uuid4()), + content=message["content"], + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now(), + ) + await memory_module.add_message(input) + + await memory_module.message_queue.message_buffer.scheduler.flush() + stored_memories = await memory_module.memory_core.storage.get_all_memories() + assert any("macbook" in message.content.lower() for message in stored_memories) + assert any("2024" in message.content for message in stored_memories) + + # Add assertions for topics + device_type_memory = next((m for m in stored_memories if "macbook" in m.content.lower()), None) + year_memory = next((m for m in stored_memories if "2024" in m.content), None) + + assert device_type_memory is not None and "Device Type" in device_type_memory.topics + assert year_memory is not None and "Device year" in year_memory.topics + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config", + [ + { + "topics": [ + Topic(name="Device Type", description="The type of device the user has"), + Topic(name="Operating System", description="The user's operating system"), + Topic(name="Device year", description="The year of the user's device"), + ], + "buffer_size": 10, + } + ], + indirect=True, +) +async def test_retrieve_memories_by_topic(memory_module): + """Test retrieving memories by topic only.""" + conversation_id = str(uuid4()) + messages = [ + UserMessageInput( + id=str(uuid4()), + content="I use Windows 11 on my PC", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=5), + ), + UserMessageInput( + id=str(uuid4()), + content="I have a MacBook Pro from 2023", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=3), + ), + UserMessageInput( + id=str(uuid4()), + content="My MacBook runs macOS Sonoma", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=1), + ), + ] + + for message in messages: + await memory_module.add_message(message) + await memory_module.message_queue.message_buffer.scheduler.flush() + + # Retrieve memories by Operating System topic + os_memories = await memory_module.retrieve_memories( + "user-123", + RetrievalConfig(topic=Topic(name="Operating System", description="The user's operating system")), + ) + + assert all("Operating System" in memory.topics for memory in os_memories) + assert any("windows 11" in memory.content.lower() for memory in os_memories) + assert any("sonoma" in memory.content.lower() for memory in os_memories) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config", + [ + { + "topics": [ + Topic(name="Device Type", description="The type of device the user has"), + Topic(name="Operating System", description="The user's operating system"), + Topic(name="Device year", description="The year of the user's device"), + ], + "buffer_size": 10, + } + ], + indirect=True, +) +async def test_retrieve_memories_by_topic_and_query(memory_module): + """Test retrieving memories using both topic and semantic search.""" + conversation_id = str(uuid4()) + messages = [ + UserMessageInput( + id=str(uuid4()), + content="I use Windows 11 on my gaming PC", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=5), + ), + UserMessageInput( + id=str(uuid4()), + content="I have a MacBook Pro from 2023", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=3), + ), + UserMessageInput( + id=str(uuid4()), + content="My MacBook runs macOS Sonoma", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=1), + ), + ] + + for message in messages: + await memory_module.add_message(message) + await memory_module.message_queue.message_buffer.scheduler.flush() + + # Retrieve memories by Operating System topic AND query about Mac + memories = await memory_module.retrieve_memories( + "user-123", + RetrievalConfig( + topic=Topic(name="Operating System", description="The user's operating system"), + query="MacBook", + ), + ) + assert len(memories) == 1 + most_relevant_memory = memories[0] + assert "macOS" in most_relevant_memory.content + assert "Windows" not in most_relevant_memory.content + + # Try another query within the same topic + windows_memories = await memory_module.retrieve_memories( + "user-123", + RetrievalConfig( + topic=Topic(name="Operating System", description="The user's operating system"), + query="What is the operating system for the user's Windows PC?", + ), + ) + + most_relevant_memory = windows_memories[0] + assert "Windows" in most_relevant_memory.content + assert "macOS" not in most_relevant_memory.content diff --git a/tests/memory_module/test_memory_storage.py b/tests/memory_module/test_memory_storage.py index 0b4fe067..506c8c5b 100644 --- a/tests/memory_module/test_memory_storage.py +++ b/tests/memory_module/test_memory_storage.py @@ -6,9 +6,10 @@ from memory_module.interfaces.types import ( AssistantMessageInput, BaseMemoryInput, - EmbedText, MemoryType, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, UserMessageInput, ) from memory_module.storage.in_memory_storage import InMemoryStorage @@ -90,10 +91,10 @@ async def test_retrieve_memories(memory_storage, sample_memory_input, sample_emb await memory_storage.store_memory(sample_memory_input, embedding_vectors=sample_embedding) # Create query embedding - query = EmbedText(text="test query", embedding_vector=sample_embedding[0]) + query = TextEmbedding(text="test query", embedding_vector=sample_embedding[0]) # Retrieve memories - memories = await memory_storage.retrieve_memories(query, "test_user", limit=1) + memories = await memory_storage.retrieve_memories(user_id="test_user", text_embedding=query, limit=1) assert len(memories) > 0 assert memories[0].content == sample_memory_input.content @@ -109,9 +110,9 @@ async def test_retrieve_memories_multiple_embeddings(memory_storage, sample_memo await memory_storage.store_memory(sample_memory_input, embedding_vectors=embeddings) # Query should match the second embedding better - query = EmbedText(text="test query", embedding_vector=[1.0] * 1536) + query = TextEmbedding(text="test query", embedding_vector=[1.0] * 1536) - memories = await memory_storage.retrieve_memories(query, "test_user", limit=1) + memories = await memory_storage.retrieve_memories(user_id="test_user", text_embedding=query, limit=1) assert len(memories) == 1 @@ -327,3 +328,167 @@ async def test_get_messages(memory_storage): assert result[memory_id_1][0].id == "msg1" assert result[memory_id_1][1].id == "msg2" assert result[memory_id_2][0].id == "msg2" + + +@pytest.mark.asyncio +async def test_retrieve_memories_by_topic(memory_storage, sample_embedding): + # Store memories with different topics + memory1 = BaseMemoryInput( + content="Memory about AI", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI"], + ) + memory2 = BaseMemoryInput( + content="Memory about nature", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["nature"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=sample_embedding) + + # Retrieve memories by single topic + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="AI", description="")], limit=10 + ) + assert len(memories) == 1 + assert memories[0].content == "Memory about AI" + assert "AI" in memories[0].topics + + # Test with non-existent topic + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="non_existent_topic", description="")], limit=10 + ) + assert len(memories) == 0 + + +@pytest.mark.asyncio +async def test_retrieve_memories_by_topic_and_embedding(memory_storage, sample_embedding): + # Store memories with different topics + memory1 = BaseMemoryInput( + content="Technical discussion about artificial intelligence", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI"], + ) + memory2 = BaseMemoryInput( + content="Another AI related memory but less relevant", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=[[0.2] * 1536]) # Less similar embedding + + # Create query embedding + query = TextEmbedding(text="AI technology", embedding_vector=sample_embedding[0]) + + # Retrieve memories using both topic and semantic similarity + memories = await memory_storage.retrieve_memories( + user_id="test_user", text_embedding=query, topics=[Topic(name="AI", description="")], limit=2 + ) + + assert len(memories) == 1 + assert memories[0].content == "Technical discussion about artificial intelligence" + + +@pytest.mark.asyncio +async def test_retrieve_memories_with_multiple_topics(memory_storage, sample_embedding): + # Store memories with multiple topics + memory1 = BaseMemoryInput( + content="Memory about AI and robotics", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "robotics"], + ) + memory2 = BaseMemoryInput( + content="Memory about AI and machine learning", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "machine learning"], + ) + memory3 = BaseMemoryInput( + content="Memory about nature", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["nature"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory3, embedding_vectors=sample_embedding) + + # Retrieve memories by AI topic (should get both AI-related memories) + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="AI", description="")], limit=10 + ) + assert len(memories) == 2 + assert all("AI" in memory.topics for memory in memories) + + # Retrieve memories by robotics topic (should get only the robotics memory) + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="robotics", description="")], limit=10 + ) + assert len(memories) == 1 + assert "robotics" in memories[0].topics + assert memories[0].content == "Memory about AI and robotics" + + +@pytest.mark.asyncio +async def test_retrieve_memories_with_multiple_topics_parameter(memory_storage, sample_embedding): + # Store memories with multiple topics + memory1 = BaseMemoryInput( + content="Memory about AI and robotics", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "robotics"], + ) + memory2 = BaseMemoryInput( + content="Memory about AI and machine learning", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "machine learning"], + ) + memory3 = BaseMemoryInput( + content="Memory about nature", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["nature"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory3, embedding_vectors=sample_embedding) + + # Retrieve memories by multiple topics + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="AI", description=""), Topic(name="robotics", description="")], limit=10 + ) + + # Should get both AI-related memories + assert len(memories) == 2 + assert any("robotics" in memory.topics for memory in memories) + assert all("AI" in memory.topics for memory in memories) From 7e33baf9de1af1668fdb34db6342f011fa494087 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Fri, 10 Jan 2025 09:43:54 -0800 Subject: [PATCH 06/11] Fix --- packages/memory_module/core/memory_core.py | 2 +- packages/memory_module/interfaces/types.py | 6 +++--- tests/memory_module/test_memory_module.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index 9d9ad301..1cc58753 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -141,7 +141,7 @@ async def process_semantic_messages( if extraction.action == "add" and extraction.facts: for fact in extraction.facts: - decision = await self._get_add_memory_processing_decision(fact.text, author_id) + decision = await self._get_add_memory_processing_decision(fact, author_id) if decision.decision == "ignore": logger.info(f"Decision to ignore fact {fact.text}") continue diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index c3821f68..2d053ebd 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -125,13 +125,13 @@ class TextEmbedding(BaseModel): class RetrievalConfig(BaseModel): query: Optional[str] = None - topic: Optional[Topic] = None + topics: Optional[List[Topic]] = None limit: Optional[int] = None @model_validator(mode="after") def check_parameters(self) -> "RetrievalConfig": - if self.query is None and self.topic is None: - raise ValueError("Either query or topic must be provided") + if self.query is None and (self.topics is None or len(self.topics) == 0): + raise ValueError("Either query or topics must be provided") return self diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index d37394b2..d1567e3d 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -536,7 +536,7 @@ async def test_retrieve_memories_by_topic(memory_module): # Retrieve memories by Operating System topic os_memories = await memory_module.retrieve_memories( "user-123", - RetrievalConfig(topic=Topic(name="Operating System", description="The user's operating system")), + RetrievalConfig(topics=[Topic(name="Operating System", description="The user's operating system")]), ) assert all("Operating System" in memory.topics for memory in os_memories) @@ -594,7 +594,7 @@ async def test_retrieve_memories_by_topic_and_query(memory_module): memories = await memory_module.retrieve_memories( "user-123", RetrievalConfig( - topic=Topic(name="Operating System", description="The user's operating system"), + topics=[Topic(name="Operating System", description="The user's operating system")], query="MacBook", ), ) @@ -607,7 +607,7 @@ async def test_retrieve_memories_by_topic_and_query(memory_module): windows_memories = await memory_module.retrieve_memories( "user-123", RetrievalConfig( - topic=Topic(name="Operating System", description="The user's operating system"), + topics=[Topic(name="Operating System", description="The user's operating system")], query="What is the operating system for the user's Windows PC?", ), ) From 55fd99079718a7d13d3dae1d82bff102de93a12c Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Fri, 10 Jan 2025 10:52:43 -0800 Subject: [PATCH 07/11] Add comment --- packages/memory_module/interfaces/types.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index 2d053ebd..7eacc595 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -124,9 +124,25 @@ class TextEmbedding(BaseModel): class RetrievalConfig(BaseModel): - query: Optional[str] = None - topics: Optional[List[Topic]] = None - limit: Optional[int] = None + """Configuration for memory retrieval operations. + + This class defines the parameters used to retrieve memories from storage. Memories can be + retrieved either by a semantic search query or by filtering for specific topics or both. + + In case of both, the memories are retrieved by the intersection of the two sets. + """ + + query: Optional[str] = Field( + default=None, description="A natural language query to search for semantically similar memories" + ) + topics: Optional[List[Topic]] = Field( + default=None, + description="List of topics to filter memories by. Only memories tagged with these topics will be retrieved", + ) + limit: Optional[int] = Field( + default=None, + description="Maximum number of memories to retrieve. If not specified, all matching memories are returned", + ) @model_validator(mode="after") def check_parameters(self) -> "RetrievalConfig": From 6b3e656c5ac7ca3dac1482e19aab099788ce8356 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Fri, 10 Jan 2025 11:08:08 -0800 Subject: [PATCH 08/11] Fix --- packages/memory_module/core/memory_core.py | 12 +++++------- packages/memory_module/interfaces/types.py | 10 +++++----- tests/memory_module/test_memory_module.py | 12 ++++++------ 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index 1cc58753..f15aff13 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -169,7 +169,7 @@ async def retrieve_memories( config: RetrievalConfig, ) -> List[Memory]: return await self._retrieve_memories( - user_id, config.query, config.topics if config.topics else None, config.limit + user_id, config.query, [config.topic] if config.topic else None, config.limit ) async def _retrieve_memories( @@ -227,12 +227,10 @@ async def remove_messages(self, message_ids: List[str]) -> None: async def _get_add_memory_processing_decision( self, new_memory_fact: SemanticFact, user_id: Optional[str] ) -> ProcessSemanticMemoryDecision: - topics = ( - [topic for topic in self.topics if topic.name in new_memory_fact.topics] if new_memory_fact.topics else None - ) - similar_memories = await self.retrieve_memories( - user_id, RetrievalConfig(query=new_memory_fact.text, topics=topics, limit=None) - ) + # topics = ( + # [topic for topic in self.topics if topic.name in new_memory_fact.topics] if new_memory_fact.topics else None # noqa: E501 + # ) + similar_memories = await self._retrieve_memories(user_id, new_memory_fact.text, None, None) decision = await self._extract_memory_processing_decision(new_memory_fact.text, similar_memories, user_id) return decision diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index 7eacc595..933a7b84 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -127,7 +127,7 @@ class RetrievalConfig(BaseModel): """Configuration for memory retrieval operations. This class defines the parameters used to retrieve memories from storage. Memories can be - retrieved either by a semantic search query or by filtering for specific topics or both. + retrieved either by a semantic search query or by filtering for a specific topic or both. In case of both, the memories are retrieved by the intersection of the two sets. """ @@ -135,9 +135,9 @@ class RetrievalConfig(BaseModel): query: Optional[str] = Field( default=None, description="A natural language query to search for semantically similar memories" ) - topics: Optional[List[Topic]] = Field( + topic: Optional[Topic] = Field( default=None, - description="List of topics to filter memories by. Only memories tagged with these topics will be retrieved", + description="Topic to filter memories by. Only memories tagged with this topic will be retrieved", ) limit: Optional[int] = Field( default=None, @@ -146,8 +146,8 @@ class RetrievalConfig(BaseModel): @model_validator(mode="after") def check_parameters(self) -> "RetrievalConfig": - if self.query is None and (self.topics is None or len(self.topics) == 0): - raise ValueError("Either query or topics must be provided") + if self.query is None and self.topic is None: + raise ValueError("Either query or topic must be provided") return self diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index 731db50e..de4be922 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -309,7 +309,7 @@ async def _validate_decision(memory_module, message: List[UserMessageInput], exp old_messages = [ UserMessageInput( id=str(uuid4()), - content="I have a Pokemon limited version Mac book.", + content="I have a Pokemon limited version Macbook.", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now() - timedelta(minutes=3), @@ -323,7 +323,7 @@ async def _validate_decision(memory_module, message: List[UserMessageInput], exp ), UserMessageInput( id=str(uuid4()), - content="I just bought a Mac book.", + content="I just bought a Macbook.", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now() - timedelta(minutes=1), @@ -333,7 +333,7 @@ async def _validate_decision(memory_module, message: List[UserMessageInput], exp [ UserMessageInput( id=str(uuid4()), - content="I have a Mac book", + content="I have a Macbook", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), @@ -538,7 +538,7 @@ async def test_retrieve_memories_by_topic(memory_module): # Retrieve memories by Operating System topic os_memories = await memory_module.retrieve_memories( "user-123", - RetrievalConfig(topics=[Topic(name="Operating System", description="The user's operating system")]), + RetrievalConfig(topic=Topic(name="Operating System", description="The user's operating system")), ) assert all("Operating System" in memory.topics for memory in os_memories) @@ -596,7 +596,7 @@ async def test_retrieve_memories_by_topic_and_query(memory_module): memories = await memory_module.retrieve_memories( "user-123", RetrievalConfig( - topics=[Topic(name="Operating System", description="The user's operating system")], + topic=Topic(name="Operating System", description="The user's operating system"), query="MacBook", ), ) @@ -609,7 +609,7 @@ async def test_retrieve_memories_by_topic_and_query(memory_module): windows_memories = await memory_module.retrieve_memories( "user-123", RetrievalConfig( - topics=[Topic(name="Operating System", description="The user's operating system")], + topic=Topic(name="Operating System", description="The user's operating system"), query="What is the operating system for the user's Windows PC?", ), ) From c8dfdeb6c11518a0f0566a73a71d4d9b8a27f814 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Fri, 10 Jan 2025 11:09:22 -0800 Subject: [PATCH 09/11] Merge remote-tracking branch 'origin/main' into aamirj/topics --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 08f3f5d9..a32562a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,4 +14,4 @@ repos: hooks: - id: mypy pass_filenames: false - args: [--ignore-missing-imports, --show-traceback] + args: [--ignore-missing-imports, --show-traceback, --verbose] From 98dbf2056dcf08a59e04f5e97f9ffd1afd00e9b8 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Fri, 10 Jan 2025 11:12:51 -0800 Subject: [PATCH 10/11] Fix --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a32562a0..08f3f5d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,4 +14,4 @@ repos: hooks: - id: mypy pass_filenames: false - args: [--ignore-missing-imports, --show-traceback, --verbose] + args: [--ignore-missing-imports, --show-traceback] From 893330a2d63766093fa4a39527cbdc95fb26a824 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Fri, 10 Jan 2025 11:18:37 -0800 Subject: [PATCH 11/11] Fix test --- tests/memory_module/test_memory_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index de4be922..de94504e 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -300,10 +300,10 @@ async def _validate_decision(memory_module, message: List[UserMessageInput], exp assert extraction.action == "add" and extraction.facts for fact in extraction.facts: decision = await memory_module.memory_core._get_add_memory_processing_decision(fact, "user-123") - if decision != expected_decision: + if decision.decision != expected_decision: # Adding this because this test is flaky and it would be good to know why. print(f"Decision: {decision}, Expected: {expected_decision}", fact, decision) - assert decision == expected_decision + assert decision.decision == expected_decision conversation_id = str(uuid4()) old_messages = [