diff --git a/packages/evals/benchmark_memory_module.py b/packages/evals/benchmark_memory_module.py index fb91fd11..a93a2740 100644 --- a/packages/evals/benchmark_memory_module.py +++ b/packages/evals/benchmark_memory_module.py @@ -16,9 +16,9 @@ from memory_module.config import LLMConfig, MemoryModuleConfig from memory_module.core.memory_module import MemoryModule -from memory_module.interfaces.types import AssistantMessage, UserMessage +from memory_module.interfaces.types import AssistantMessage, RetrievalConfig, UserMessage -from evals.helpers import Dataset, DatasetItem, load_dataset, setup_mlflow +from evals.helpers import Dataset, DatasetItem, SessionMessage, load_dataset, setup_mlflow from evals.metrics import string_check_metric setup_mlflow(experiment_name="memory_module") @@ -48,7 +48,7 @@ def __exit__(self, exc_type, exc_value, traceback): os.remove(self._db_path) -async def add_messages(memory_module: MemoryModule, messages: List[dict]): +async def add_messages(memory_module: MemoryModule, messages: List[SessionMessage]): def create_message(**kwargs): params = { "id": str(uuid.uuid4()), @@ -94,7 +94,7 @@ async def benchmark_memory_module(input: DatasetItem): # buffer size has to be the same as the session length to trigger sm processing with MemoryModuleManager(buffer_size=len(session)) as memory_module: await add_messages(memory_module, messages=session) - memories = await memory_module.retrieve_memories(query, user_id=None, limit=None) + memories = await memory_module.retrieve_memories(None, RetrievalConfig(query=query, limit=None)) return { "input": { diff --git a/packages/memory_module/__init__.py b/packages/memory_module/__init__.py index cc9d2ae3..5fd170e1 100644 --- a/packages/memory_module/__init__.py +++ b/packages/memory_module/__init__.py @@ -9,6 +9,7 @@ Memory, Message, MessageInput, + RetrievalConfig, ShortTermMemoryRetrievalConfig, UserMessage, UserMessageInput, @@ -29,6 +30,7 @@ "MessageInput", "AssistantMessage", "AssistantMessageInput", + "RetrievalConfig", "ShortTermMemoryRetrievalConfig", "MemoryMiddleware", ] diff --git a/packages/memory_module/config.py b/packages/memory_module/config.py index e3494ed9..0cd2d83c 100644 --- a/packages/memory_module/config.py +++ b/packages/memory_module/config.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, ConfigDict, Field +from memory_module.interfaces.types import Topic + class LLMConfig(BaseModel): """Configuration for LLM service.""" @@ -16,6 +18,18 @@ class LLMConfig(BaseModel): embedding_model: Optional[str] = None +DEFAULT_TOPICS = [ + Topic( + name="General Interests and Preferences", + description="When a user mentions specific events or actions, focus on the underlying interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference, not the date or location).", # noqa: E501 + ), + Topic( + name="General Facts about the user", + description="Facts that describe relevant information about the user, such as details about where they live or things they own.", # noqa: E501 + ), +] + + class MemoryModuleConfig(BaseModel): """Configuration for memory module components. @@ -35,4 +49,7 @@ class MemoryModuleConfig(BaseModel): description="Seconds to wait before processing a conversation", ) llm: LLMConfig = Field(description="LLM service configuration") + topics: list[Topic] = Field( + default=DEFAULT_TOPICS, description="List of topics that the memory module should listen to", min_length=1 + ) enable_logging: bool = Field(default=False, description="Enable verbose logging for memory module") diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index 9c12912a..f15aff13 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -10,12 +10,14 @@ from memory_module.interfaces.base_memory_storage import BaseMemoryStorage from memory_module.interfaces.types import ( BaseMemoryInput, - EmbedText, Memory, MemoryType, Message, MessageInput, + RetrievalConfig, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, ) from memory_module.services.llm_service import LLMService from memory_module.storage.in_memory_storage import InMemoryStorage @@ -51,6 +53,11 @@ class SemanticFact(BaseModel): default_factory=set, description="The indices of the messages that the fact was extracted from.", ) + # TODO: Add a validator to ensure that topics are valid + topics: Optional[List[str]] = Field( + default=None, + description="The name of the topic that the fact is most relevant to.", # noqa: E501 + ) class SemanticMemoryExtraction(BaseModel): @@ -106,6 +113,7 @@ def __init__( self.storage: BaseMemoryStorage = storage or ( SQLiteMemoryStorage(db_path=config.db_path) if config.db_path is not None else InMemoryStorage() ) + self.topics = config.topics async def process_semantic_messages( self, @@ -133,7 +141,7 @@ async def process_semantic_messages( if extraction.action == "add" and extraction.facts: for fact in extraction.facts: - decision = await self._get_add_memory_processing_decision(fact.text, author_id) + decision = await self._get_add_memory_processing_decision(fact, author_id) if decision.decision == "ignore": logger.info(f"Decision to ignore fact {fact.text}") continue @@ -145,6 +153,7 @@ async def process_semantic_messages( user_id=author_id, message_attributions=list(message_ids), memory_type=MemoryType.SEMANTIC, + topics=fact.topics, ) embed_vectors = await self._get_semantic_fact_embeddings(fact.text, metadata) await self.storage.store_memory(memory, embedding_vectors=embed_vectors) @@ -154,7 +163,22 @@ async def process_episodic_messages(self, messages: List[Message]) -> None: # TODO: Implement episodic memory processing await self._extract_episodic_memory_from_messages(messages) - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: + return await self._retrieve_memories( + user_id, config.query, [config.topic] if config.topic else None, config.limit + ) + + async def _retrieve_memories( + self, + user_id: Optional[str], + query: Optional[str], + topics: Optional[List[Topic]], + limit: Optional[int], + ) -> List[Memory]: """Retrieve memories based on a query. Steps: @@ -162,8 +186,14 @@ async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Opt 2. Find relevant memories 3. Possibly rerank or filter results """ - embedText = EmbedText(text=query, embedding_vector=await self._get_query_embedding(query)) - return await self.storage.retrieve_memories(embedText, user_id, limit) + if query: + text_embedding = TextEmbedding(text=query, embedding_vector=await self._get_query_embedding(query)) + else: + text_embedding = None + + return await self.storage.retrieve_memories( + user_id=user_id, text_embedding=text_embedding, topics=topics, limit=limit + ) async def update_memory(self, memory_id: str, updated_memory: str) -> None: metadata = await self._extract_metadata_from_fact(updated_memory) @@ -195,10 +225,13 @@ async def remove_messages(self, message_ids: List[str]) -> None: logger.info("messages {} are removed".format(",".join(message_ids))) async def _get_add_memory_processing_decision( - self, new_memory_fact: str, user_id: Optional[str] + self, new_memory_fact: SemanticFact, user_id: Optional[str] ) -> ProcessSemanticMemoryDecision: - similar_memories = await self.retrieve_memories(new_memory_fact, user_id, None) - decision = await self._extract_memory_processing_decision(new_memory_fact, similar_memories, user_id) + # topics = ( + # [topic for topic in self.topics if topic.name in new_memory_fact.topics] if new_memory_fact.topics else None # noqa: E501 + # ) + similar_memories = await self._retrieve_memories(user_id, new_memory_fact.text, None, None) + decision = await self._extract_memory_processing_decision(new_memory_fact.text, similar_memories, user_id) return decision async def _extract_memory_processing_decision( @@ -306,6 +339,9 @@ async def _extract_semantic_fact_from_messages( else: # we explicitly ignore internal messages continue + topics = "\n".join( + [f"{topic.description}" for topic in self.topics] + ) existing_memories_str = "" if existing_memories: @@ -318,11 +354,7 @@ async def _extract_semantic_fact_from_messages( that will remain relevant over time, even if the user is mentioning short-term plans or events. Prioritize: -- General Interests and Preferences: When a user mentions specific events or actions, focus on the underlying -interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference, -not the date or location). -- Facts or Details about user: Extract facts that describe relevant information about the user, such as details about things they own. -- Facts about the user that the assistant might find useful. +{topics} Avoid: - Extraction memories that already exist in the system. If a fact is already stored, ignore it. @@ -335,7 +367,6 @@ async def _extract_semantic_fact_from_messages( {messages_str} """ # noqa: E501 - llm_messages = [ {"role": "system", "content": system_message}, { diff --git a/packages/memory_module/core/memory_module.py b/packages/memory_module/core/memory_module.py index 8163cd8f..e4e86aa8 100644 --- a/packages/memory_module/core/memory_module.py +++ b/packages/memory_module/core/memory_module.py @@ -7,7 +7,13 @@ from memory_module.interfaces.base_memory_core import BaseMemoryCore from memory_module.interfaces.base_memory_module import BaseMemoryModule from memory_module.interfaces.base_message_queue import BaseMessageQueue -from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + Memory, + Message, + MessageInput, + RetrievalConfig, + ShortTermMemoryRetrievalConfig, +) from memory_module.services.llm_service import LLMService from memory_module.utils.logging import configure_logging @@ -50,10 +56,14 @@ async def add_message(self, message: MessageInput) -> Message: await self.message_queue.enqueue(message_res) return message_res - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: """Retrieve relevant memories based on a query.""" - logger.debug(f"retrieve memories from (query: {query}, user_id: {user_id}, limit: {limit})") - memories = await self.memory_core.retrieve_memories(query, user_id, limit) + logger.debug(f"retrieve memories from (query: {config.query}, user_id: {user_id}, limit: {config.limit})") + memories = await self.memory_core.retrieve_memories(user_id=user_id, config=config) logger.debug(f"retrieved memories: {memories}") return memories diff --git a/packages/memory_module/interfaces/base_memory_core.py b/packages/memory_module/interfaces/base_memory_core.py index 8efc02da..201c2713 100644 --- a/packages/memory_module/interfaces/base_memory_core.py +++ b/packages/memory_module/interfaces/base_memory_core.py @@ -1,7 +1,13 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + Memory, + Message, + MessageInput, + RetrievalConfig, + ShortTermMemoryRetrievalConfig, +) class BaseMemoryCore(ABC): @@ -22,7 +28,11 @@ async def process_episodic_messages(self, messages: List[Message]) -> None: pass @abstractmethod - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: """Retrieve memories based on a query.""" pass diff --git a/packages/memory_module/interfaces/base_memory_module.py b/packages/memory_module/interfaces/base_memory_module.py index fb8890e4..69902d69 100644 --- a/packages/memory_module/interfaces/base_memory_module.py +++ b/packages/memory_module/interfaces/base_memory_module.py @@ -1,7 +1,13 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + Memory, + Message, + MessageInput, + RetrievalConfig, + ShortTermMemoryRetrievalConfig, +) class BaseMemoryModule(ABC): @@ -13,7 +19,11 @@ async def add_message(self, message: MessageInput) -> Message: pass @abstractmethod - async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: + async def retrieve_memories( + self, + user_id: Optional[str], + config: RetrievalConfig, + ) -> List[Memory]: """Retrieve relevant memories based on a query.""" pass diff --git a/packages/memory_module/interfaces/base_memory_storage.py b/packages/memory_module/interfaces/base_memory_storage.py index 352baa21..75036c78 100644 --- a/packages/memory_module/interfaces/base_memory_storage.py +++ b/packages/memory_module/interfaces/base_memory_storage.py @@ -3,11 +3,12 @@ from memory_module.interfaces.types import ( BaseMemoryInput, - EmbedText, Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, ) @@ -47,7 +48,12 @@ async def store_short_term_memory(self, message: MessageInput) -> Message: @abstractmethod async def retrieve_memories( - self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None + self, + *, + user_id: Optional[str], + text_embedding: Optional[TextEmbedding] = None, + topics: Optional[List[Topic]] = None, + limit: Optional[int] = None, ) -> List[Memory]: """Retrieve memories based on a query. diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index d7bd4ad3..933a7b84 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -104,6 +104,12 @@ class BaseMemoryInput(BaseModel): memory_type: MemoryType user_id: Optional[str] = None message_attributions: Optional[List[str]] = Field(default=[]) + topics: Optional[List[str]] = None + + +class Topic(BaseModel): + name: str = Field(description="A unique name of the topic that the memory module should listen to") + description: str = Field(description="Description of the topic") class Memory(BaseMemoryInput): @@ -112,12 +118,40 @@ class Memory(BaseMemoryInput): id: str -class EmbedText(BaseModel): +class TextEmbedding(BaseModel): text: str embedding_vector: List[float] -class ShortTermMemoryRetrievalConfig(BaseModel): +class RetrievalConfig(BaseModel): + """Configuration for memory retrieval operations. + + This class defines the parameters used to retrieve memories from storage. Memories can be + retrieved either by a semantic search query or by filtering for a specific topic or both. + + In case of both, the memories are retrieved by the intersection of the two sets. + """ + + query: Optional[str] = Field( + default=None, description="A natural language query to search for semantically similar memories" + ) + topic: Optional[Topic] = Field( + default=None, + description="Topic to filter memories by. Only memories tagged with this topic will be retrieved", + ) + limit: Optional[int] = Field( + default=None, + description="Maximum number of memories to retrieve. If not specified, all matching memories are returned", + ) + + @model_validator(mode="after") + def check_parameters(self) -> "RetrievalConfig": + if self.query is None and self.topic is None: + raise ValueError("Either query or topic must be provided") + return self + + +class ShortTermMemoryRetrievalConfig(RetrievalConfig): n_messages: Optional[int] = None # Number of messages to retrieve last_minutes: Optional[float] = None # Time frame in minutes before: Optional[datetime] = None # Retrieve messages up until a specific timestamp diff --git a/packages/memory_module/storage/in_memory_storage.py b/packages/memory_module/storage/in_memory_storage.py index 12c3cae6..6e8024d2 100644 --- a/packages/memory_module/storage/in_memory_storage.py +++ b/packages/memory_module/storage/in_memory_storage.py @@ -14,13 +14,14 @@ AssistantMessage, AssistantMessageInput, BaseMemoryInput, - EmbedText, InternalMessage, InternalMessageInput, Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, UserMessage, UserMessageInput, ) @@ -112,26 +113,54 @@ async def store_short_term_memory(self, message: MessageInput) -> Message: return message_obj async def retrieve_memories( - self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None + self, + *, + user_id: Optional[str], + text_embedding: Optional[TextEmbedding] = None, + topics: Optional[List[Topic]] = None, + limit: Optional[int] = None, ) -> List[Memory]: limit = limit or self.default_limit - sorted_memories: list[_MemorySimilarity] = [] + memories = [] + + # Filter memories by user_id and topics first + filtered_memories = list(self.storage["memories"].values()) + if user_id: + filtered_memories = [m for m in filtered_memories if m.user_id == user_id] + if topics: + filtered_memories = [ + m for m in filtered_memories if m.topics and any(topic.name in m.topics for topic in topics) + ] - for memory_id, embeddings in self.storage["embeddings"].items(): - memory = self.storage["memories"][memory_id] - if user_id and memory.user_id != user_id: - continue + # If we have text_embedding, calculate similarities and sort + if text_embedding: + sorted_memories: list[_MemorySimilarity] = [] - # Find the embedding with highest similarity (lowest distance) - best_similarity = float("-inf") - for embedding in embeddings: - similarity = self._cosine_similarity(embedText.embedding_vector, embedding) - best_similarity = max(best_similarity, similarity) + for memory in filtered_memories: + embeddings = self.storage["embeddings"].get(memory.id, []) + if not embeddings: + continue - sorted_memories.append(_MemorySimilarity(memory, best_similarity)) + # Find the embedding with lowest distance + best_distance = float("inf") + for embedding in embeddings: + distance = self.l2_distance(text_embedding.embedding_vector, embedding) + best_distance = min(best_distance, distance) - sorted_memories.sort(key=lambda x: x.similarity, reverse=True) - return [Memory(**item.memory.__dict__) for item in sorted_memories[:limit]] + # Filter based on distance threshold + if best_distance > 1.0: # adjust threshold as needed + continue + + sorted_memories.append(_MemorySimilarity(memory, best_distance)) + + # Sort by distance (ascending instead of descending) + sorted_memories.sort(key=lambda x: x.similarity) + memories = [Memory(**item.memory.__dict__) for item in sorted_memories[:limit]] + else: + # If no embedding, sort by created_at + memories = sorted(filtered_memories, key=lambda x: x.created_at, reverse=True)[:limit] + + return memories async def get_memories(self, memory_ids: List[str]) -> List[Memory]: return [ @@ -169,8 +198,12 @@ async def remove_memories(self, memory_ids: List[str]) -> None: self.storage["embeddings"].pop(memory_id, None) self.storage["memories"].pop(memory_id, None) - def _cosine_similarity(self, memory_vector: List[float], query_vector: List[float]) -> float: - return np.dot(np.array(query_vector), np.array(memory_vector)) + def l2_distance(self, memory_vector: List[float], query_vector: List[float]) -> float: + memory_array = np.array(memory_vector) + query_array = np.array(query_vector) + + # Compute L2 (Euclidean) distance: sqrt(sum((a-b)^2)) + return np.sqrt(np.sum((memory_array - query_array) ** 2)) async def clear_memories(self, user_id: str) -> None: memory_ids_for_user = [ diff --git a/packages/memory_module/storage/migrations/14_add_topic_to_memories.sql b/packages/memory_module/storage/migrations/14_add_topic_to_memories.sql new file mode 100644 index 00000000..511d49f6 --- /dev/null +++ b/packages/memory_module/storage/migrations/14_add_topic_to_memories.sql @@ -0,0 +1 @@ +ALTER TABLE memories ADD COLUMN topics TEXT; \ No newline at end of file diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index eef8a445..d7e92b07 100644 --- a/packages/memory_module/storage/sqlite_memory_storage.py +++ b/packages/memory_module/storage/sqlite_memory_storage.py @@ -8,12 +8,13 @@ from memory_module.interfaces.base_memory_storage import BaseMemoryStorage from memory_module.interfaces.types import ( BaseMemoryInput, - EmbedText, InternalMessageInput, Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, ) from memory_module.storage.sqlite_storage import SQLiteStorage from memory_module.storage.utils import build_message_from_dict @@ -40,16 +41,20 @@ async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List async with self.storage.transaction() as cursor: # Store the memory + # Convert topics list to comma-separated string if it's a list + topics_str = ",".join(memory.topics) if isinstance(memory.topics, list) else memory.topics + await cursor.execute( """INSERT INTO memories - (id, content, created_at, user_id, memory_type) - VALUES (?, ?, ?, ?, ?)""", + (id, content, created_at, user_id, memory_type, topics) + VALUES (?, ?, ?, ?, ?, ?)""", ( memory_id, memory.content, memory.created_at.astimezone(datetime.timezone.utc), memory.user_id, memory.memory_type.value, + topics_str, ), ) @@ -107,112 +112,112 @@ async def update_memory(self, memory_id: str, updated_memory: str, *, embedding_ ) async def retrieve_memories( - self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None + self, + *, + user_id: Optional[str], + text_embedding: Optional[TextEmbedding] = None, + topics: Optional[List[Topic]] = None, + limit: Optional[int] = None, ) -> List[Memory]: - """Retrieve memories based on a query.""" - query = """ - WITH ranked_memories AS ( - SELECT - e.memory_id, - distance - FROM vec_items - JOIN embeddings e ON vec_items.memory_embedding_id = e.id - WHERE vec_items.embedding MATCH ? AND K = ? AND distance < ? - ORDER BY distance ASC - ) + base_query = """ SELECT - m.id, - m.content, - m.created_at, - m.user_id, - m.memory_type, - ma.message_id, - rm.distance - FROM ranked_memories rm - JOIN memories m ON m.id = rm.memory_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions + {distance_select} + FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id - ORDER BY rm.distance ASC + {embedding_join} + WHERE 1=1 + {topic_filter} + {user_filter} + GROUP BY m.id + {order_by} + {limit_clause} """ - rows = await self.storage.fetch_all( - query, - ( - sqlite_vec.serialize_float32(embedText.embedding_vector), - limit or self.default_limit, - 1.0, - ), + params = [] + + # Handle embedding search first since its params come first in the query + embedding_join = "" + distance_select = "" + order_by = "ORDER BY m.created_at DESC" + + if text_embedding: + embedding_join = """ + JOIN ( + SELECT + e.memory_id, + distance + FROM vec_items + JOIN embeddings e ON vec_items.memory_embedding_id = e.id + WHERE vec_items.embedding MATCH ? AND K = ? AND distance < ? + ) rm ON m.id = rm.memory_id + """ + distance_select = ", rm.distance as _distance" + order_by = "ORDER BY rm.distance ASC" + params.extend( + [sqlite_vec.serialize_float32(text_embedding.embedding_vector), limit or self.default_limit, 1.0] + ) + + # Handle topic and user filters after embedding params + topic_filter = "" + if topics: + # Create a single AND condition with multiple LIKE clauses + topic_filter = " AND (" + " OR ".join(["m.topics LIKE ?"] * len(topics)) + ")" + params.extend(f"%{t.name}%" for t in topics) + + user_filter = "" + if user_id: + user_filter = "AND m.user_id = ?" + params.append(user_id) + + # Handle limit last + limit_clause = "" + if limit and not text_embedding: # Only add LIMIT if not using vector search + limit_clause = "LIMIT ?" + params.append(limit or self.default_limit) + + query = base_query.format( + distance_select=distance_select, + embedding_join=embedding_join, + topic_filter=topic_filter, + user_filter=user_filter, + order_by=order_by, + limit_clause=limit_clause, ) - # Group rows by memory_id to handle message attributions - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - "distance": row["distance"], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + rows = await self.storage.fetch_all(query, tuple(params)) + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def clear_memories(self, user_id: str) -> None: """Clear all memories for a given user.""" query = """ SELECT - m.id, - e.id AS embed_id + m.id FROM memories m - LEFT JOIN embeddings e - WHERE m.user_id = ? AND m.id = e.memory_id + WHERE m.user_id = ? """ - id_rows = await self.storage.fetch_all(query, (user_id,)) - memory_id_list = [row["id"] for row in id_rows] - embed_id_list = [row["embed_id"] for row in id_rows] - + rows = await self.storage.fetch_all(query, (user_id,)) + memory_ids = [row["id"] for row in rows] # Remove memory - await self._remove_memories_and_embeddings(memory_id_list, ",".join(["?"] * len(embed_id_list)), embed_id_list) + await self.remove_memories(memory_ids) async def get_memory(self, memory_id: int) -> Optional[Memory]: - """Retrieve a memory with its message attributions.""" query = """ SELECT - m.id, - m.content, - m.created_at, - m.updated_at, - m.user_id, - m.memory_type, - ma.message_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id WHERE m.id = ? + GROUP BY m.id """ - rows = await self.storage.fetch_all(query, (memory_id,)) - - if not rows: + row = await self.storage.fetch_one(query, (memory_id,)) + if not row: return None - # First row contains the memory data - memory_data = { - "id": rows[0]["id"], - "content": rows[0]["content"], - "created_at": rows[0]["created_at"], - "updated_at": rows[0]["updated_at"], - "user_id": rows[0]["user_id"], - "memory_type": rows[0]["memory_type"], - "message_attributions": [row["message_id"] for row in rows if row["message_id"]], - } - - return Memory(**memory_data) + return self._build_memory(row, (row["message_attributions"] or "").split(",")) async def get_all_memories( self, limit: Optional[int] = None, message_ids: Optional[List[str]] = None @@ -220,12 +225,8 @@ async def get_all_memories( """Retrieve all memories with their message attributions.""" query = """ SELECT - m.id, - m.content, - m.created_at, - m.user_id, - m.memory_type, - ma.message_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id """ @@ -236,32 +237,17 @@ async def get_all_memories( message_ids, ) - query += " ORDER BY m.created_at DESC" + query += """ + GROUP BY m.id + ORDER BY m.created_at DESC + """ if limit is not None: query += " LIMIT ?" params += (limit,) - rows = await self.storage.fetch_all(query, tuple(params)) - - # Group rows by memory_id - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + rows = await self.storage.fetch_all(query, params) + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def store_short_term_memory(self, message: MessageInput) -> Message: """Store a short-term memory entry.""" @@ -332,93 +318,66 @@ async def retrieve_chat_history( return [build_message_from_dict(row) for row in rows][::-1] async def get_memories(self, memory_ids: List[str]) -> List[Memory]: - query = """ + query = f""" SELECT - m.id, - m.content, - m.created_at, - m.user_id, - m.memory_type, - ma.message_id + m.*, + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id - WHERE m.id IN ({}) - """.format(",".join(["?"] * len(memory_ids))) + WHERE m.id IN ({','.join(['?'] * len(memory_ids))}) + """ rows = await self.storage.fetch_all(query, tuple(memory_ids)) - # Group rows by memory_id - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def get_user_memories(self, user_id: str) -> List[Memory]: - """Get memories based on user id.""" query = """ SELECT m.*, - ma.message_id + GROUP_CONCAT(ma.message_id) as message_attributions FROM memories m LEFT JOIN memory_attributions ma ON m.id = ma.memory_id WHERE m.user_id = ? + GROUP BY m.id """ rows = await self.storage.fetch_all(query, (user_id,)) - - # Group rows by memory_id - memories_dict = {} - for row in rows: - memory_id = row["id"] - if memory_id not in memories_dict: - memories_dict[memory_id] = { - "id": memory_id, - "content": row["content"], - "created_at": row["created_at"], - "user_id": row["user_id"], - "memory_type": row["memory_type"], - "message_attributions": [], - } - - if row["message_id"]: - memories_dict[memory_id]["message_attributions"].append(row["message_id"]) - - return [Memory(**memory_data) for memory_data in memories_dict.values()] + return [self._build_memory(row, (row["message_attributions"] or "").split(",")) for row in rows] async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: - """Get messages based on memory ids.""" - query = """ - SELECT ma.memory_id, m.* + query = f""" + SELECT + ma.memory_id as _memory_id, + m.* FROM memory_attributions ma JOIN messages m ON ma.message_id = m.id - WHERE ma.memory_id IN ({}) - """.format(",".join(["?"] * len(memory_ids))) + WHERE ma.memory_id IN ({','.join(['?'] * len(memory_ids))}) + """ rows = await self.storage.fetch_all(query, tuple(memory_ids)) messages_dict: Dict[str, List[Message]] = {} for row in rows: - memory_id = row["memory_id"] + memory_id = row["_memory_id"] if memory_id not in messages_dict: messages_dict[memory_id] = [] - - messages_dict[memory_id].append(build_message_from_dict(row)) + messages_dict[memory_id].append( + build_message_from_dict({k: v for k, v in row.items() if not k.startswith("_")}) + ) return messages_dict + def _build_memory(self, memory_values: dict, message_attributions: List[str]) -> Memory: + memory_keys = ["id", "content", "created_at", "user_id", "memory_type", "topics"] + # Convert topics string back to list if it exists + if memory_values.get("topics"): + memory_values["topics"] = memory_values["topics"].split(",") + return Memory( + **{k: v for k, v in memory_values.items() if k in memory_keys}, + message_attributions=message_attributions, + ) + async def remove_messages(self, message_ids: List[str]) -> None: async with self.storage.transaction() as cursor: await cursor.execute( @@ -427,21 +386,12 @@ async def remove_messages(self, message_ids: List[str]) -> None: ) async def remove_memories(self, memory_ids: List[str]) -> None: - query = """ - SELECT - id - FROM embeddings - WHERE memory_id in ({}) - """.format(",".join(["?"] * len(memory_ids))) - await self._remove_memories_and_embeddings(memory_ids, query) - - async def _remove_memories_and_embeddings( - self, memory_ids: List[str], embed_query: str, embed_ids: Optional[List[str]] = None - ) -> None: async with self.storage.transaction() as cursor: await cursor.execute( - "DELETE FROM vec_items WHERE memory_embedding_id in ({})".format(embed_query), - tuple(embed_ids or memory_ids), + """DELETE FROM vec_items WHERE memory_embedding_id in ( + SELECT id FROM embeddings WHERE memory_id in ({}) + )""".format(",".join(["?"] * len(memory_ids))), + tuple(memory_ids), ) await cursor.execute( diff --git a/src/tech_assistant_agent/tools.py b/src/tech_assistant_agent/tools.py index 4e3193a4..072f175f 100644 --- a/src/tech_assistant_agent/tools.py +++ b/src/tech_assistant_agent/tools.py @@ -3,7 +3,7 @@ from botbuilder.core import TurnContext from botbuilder.schema import Activity -from memory_module import BaseMemoryModule, Memory +from memory_module import BaseMemoryModule, Memory, RetrievalConfig from pydantic import BaseModel, Field from teams.ai.citations import AIEntity, Appearance, ClientCitation @@ -49,7 +49,7 @@ async def get_candidate_tasks(candidate_tasks: GetCandidateTasks) -> str: async def get_memorized_fields(memory_module: BaseMemoryModule, fields_to_retrieve: GetMemorizedFields) -> str: empty_obj: dict = {} for query in fields_to_retrieve.queries_for_fields: - result = await memory_module.retrieve_memories(query, None, None) + result = await memory_module.retrieve_memories(None, RetrievalConfig(query=query, limit=None)) print("Getting memorized queries: ", query) print(result) print("---") diff --git a/tests/memory_module/test_in_memory_storage.py b/tests/memory_module/test_in_memory_storage.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index 08940b42..de94504e 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -11,7 +11,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from memory_module.config import MemoryModuleConfig +from memory_module.config import DEFAULT_TOPICS, MemoryModuleConfig, Topic # noqa: I001 from memory_module.core.memory_core import ( EpisodicMemoryExtraction, MemoryCore, @@ -21,6 +21,8 @@ ) from memory_module.core.memory_module import MemoryModule from memory_module.interfaces.types import ( + AssistantMessageInput, + RetrievalConfig, ShortTermMemoryRetrievalConfig, UserMessageInput, ) @@ -33,16 +35,21 @@ @pytest.fixture -def config(): +def config(request): """Fixture to create test config.""" + params = request.param if hasattr(request, "param") else {} llm_config = build_llm_config() + buffer_size = params.get("buffer_size", 5) + timeout_seconds = params.get("timeout_seconds", 60) + topics = params.get("topics", DEFAULT_TOPICS) if not llm_config.api_key: pytest.skip("OpenAI API key not provided") return MemoryModuleConfig( db_path=Path(__file__).parent / "data" / "tests" / "memory_module.db", - buffer_size=5, - timeout_seconds=60, + buffer_size=buffer_size, + timeout_seconds=timeout_seconds, llm=llm_config, + topics=topics, ) @@ -155,7 +162,7 @@ async def test_simple_conversation(memory_module): assert any(message.id in stored_memories[0].message_attributions for message in messages) assert all(memory.memory_type == "semantic" for memory in stored_memories) - result = await memory_module.retrieve_memories("apple pie", "", 1) + result = await memory_module.retrieve_memories("user-123", RetrievalConfig(query="apple pie", limit=1)) assert len(result) == 1 assert result[0].id == next(memory.id for memory in stored_memories if "apple pie" in memory.content) @@ -287,35 +294,46 @@ async def test_short_term_memory(memory_module): @pytest.mark.asyncio async def test_add_memory_processing_decision(memory_module): """Test whether to process adding memory""" + + async def _validate_decision(memory_module, message: List[UserMessageInput], expected_decision: str): + extraction = await memory_module.memory_core._extract_semantic_fact_from_messages(message) + assert extraction.action == "add" and extraction.facts + for fact in extraction.facts: + decision = await memory_module.memory_core._get_add_memory_processing_decision(fact, "user-123") + if decision.decision != expected_decision: + # Adding this because this test is flaky and it would be good to know why. + print(f"Decision: {decision}, Expected: {expected_decision}", fact, decision) + assert decision.decision == expected_decision + conversation_id = str(uuid4()) old_messages = [ UserMessageInput( id=str(uuid4()), - content="I have a Pokemon limited version Mac book.", + content="I have a Pokemon limited version Macbook.", author_id="user-123", conversation_ref=conversation_id, - created_at=datetime.strptime("2024-09-01", "%Y-%m-%d"), + created_at=datetime.now() - timedelta(minutes=3), ), UserMessageInput( id=str(uuid4()), content="I bought a pink iphone.", author_id="user-123", conversation_ref=conversation_id, - created_at=datetime.strptime("2024-09-03", "%Y-%m-%d"), + created_at=datetime.now() - timedelta(minutes=2), ), UserMessageInput( id=str(uuid4()), - content="I just had another Mac book.", + content="I just bought a Macbook.", author_id="user-123", conversation_ref=conversation_id, - created_at=datetime.strptime("2024-10-12", "%Y-%m-%d"), + created_at=datetime.now() - timedelta(minutes=1), ), ] new_messages = [ [ UserMessageInput( id=str(uuid4()), - content="I have a Mac book", + content="I have a Macbook", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), @@ -341,14 +359,6 @@ async def test_add_memory_processing_decision(memory_module): await _validate_decision(memory_module, new_messages[1], "add") -async def _validate_decision(memory_module, message: List[UserMessageInput], expected_decision: str): - extraction = await memory_module.memory_core._extract_semantic_fact_from_messages(message) - assert extraction.action == "add" and extraction.facts - for fact in extraction.facts: - decision = await memory_module.memory_core._get_add_memory_processing_decision(fact.text, "user-123") - assert decision.decision == expected_decision - - @pytest.mark.asyncio async def test_remove_messages(memory_module): conversation1_id = str(uuid4()) @@ -420,3 +430,190 @@ async def test_remove_messages(memory_module): conversation_refs = list(updated_buffer.keys()) assert len(conversation_refs) == 1 assert conversation_refs[0] == conversation3_id + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config", + [ + { + "topics": [ + Topic(name="Device Type", description="The type of device the user has"), + Topic(name="Operating System", description="The user's operating system"), + Topic(name="Device year", description="The year of the user's device"), + ], + "buffer_size": 10, + } + ], + indirect=True, +) +async def test_topic_extraction(memory_module): + conversation_id = str(uuid4()) + messages = [ + {"role": "user", "content": "I need help with my device..."}, + {"role": "assistant", "content": "I'm sorry to hear that. What device do you have?"}, + {"role": "user", "content": "I have a Macbook"}, + {"role": "assistant", "content": "What is the year of your device?"}, + {"role": "user", "content": "2024"}, + ] + + for message in messages: + if message["role"] == "user": + input = UserMessageInput( + id=str(uuid4()), + content=message["content"], + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now(), + ) + else: + input = AssistantMessageInput( + id=str(uuid4()), + content=message["content"], + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now(), + ) + await memory_module.add_message(input) + + await memory_module.message_queue.message_buffer.scheduler.flush() + stored_memories = await memory_module.memory_core.storage.get_all_memories() + assert any("macbook" in message.content.lower() for message in stored_memories) + assert any("2024" in message.content for message in stored_memories) + + # Add assertions for topics + device_type_memory = next((m for m in stored_memories if "macbook" in m.content.lower()), None) + year_memory = next((m for m in stored_memories if "2024" in m.content), None) + + assert device_type_memory is not None and "Device Type" in device_type_memory.topics + assert year_memory is not None and "Device year" in year_memory.topics + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config", + [ + { + "topics": [ + Topic(name="Device Type", description="The type of device the user has"), + Topic(name="Operating System", description="The user's operating system"), + Topic(name="Device year", description="The year of the user's device"), + ], + "buffer_size": 10, + } + ], + indirect=True, +) +async def test_retrieve_memories_by_topic(memory_module): + """Test retrieving memories by topic only.""" + conversation_id = str(uuid4()) + messages = [ + UserMessageInput( + id=str(uuid4()), + content="I use Windows 11 on my PC", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=5), + ), + UserMessageInput( + id=str(uuid4()), + content="I have a MacBook Pro from 2023", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=3), + ), + UserMessageInput( + id=str(uuid4()), + content="My MacBook runs macOS Sonoma", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=1), + ), + ] + + for message in messages: + await memory_module.add_message(message) + await memory_module.message_queue.message_buffer.scheduler.flush() + + # Retrieve memories by Operating System topic + os_memories = await memory_module.retrieve_memories( + "user-123", + RetrievalConfig(topic=Topic(name="Operating System", description="The user's operating system")), + ) + + assert all("Operating System" in memory.topics for memory in os_memories) + assert any("windows 11" in memory.content.lower() for memory in os_memories) + assert any("sonoma" in memory.content.lower() for memory in os_memories) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config", + [ + { + "topics": [ + Topic(name="Device Type", description="The type of device the user has"), + Topic(name="Operating System", description="The user's operating system"), + Topic(name="Device year", description="The year of the user's device"), + ], + "buffer_size": 10, + } + ], + indirect=True, +) +async def test_retrieve_memories_by_topic_and_query(memory_module): + """Test retrieving memories using both topic and semantic search.""" + conversation_id = str(uuid4()) + messages = [ + UserMessageInput( + id=str(uuid4()), + content="I use Windows 11 on my gaming PC", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=5), + ), + UserMessageInput( + id=str(uuid4()), + content="I have a MacBook Pro from 2023", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=3), + ), + UserMessageInput( + id=str(uuid4()), + content="My MacBook runs macOS Sonoma", + author_id="user-123", + conversation_ref=conversation_id, + created_at=datetime.now() - timedelta(minutes=1), + ), + ] + + for message in messages: + await memory_module.add_message(message) + await memory_module.message_queue.message_buffer.scheduler.flush() + + # Retrieve memories by Operating System topic AND query about Mac + memories = await memory_module.retrieve_memories( + "user-123", + RetrievalConfig( + topic=Topic(name="Operating System", description="The user's operating system"), + query="MacBook", + ), + ) + assert len(memories) == 1 + most_relevant_memory = memories[0] + assert "macOS" in most_relevant_memory.content + assert "Windows" not in most_relevant_memory.content + + # Try another query within the same topic + windows_memories = await memory_module.retrieve_memories( + "user-123", + RetrievalConfig( + topic=Topic(name="Operating System", description="The user's operating system"), + query="What is the operating system for the user's Windows PC?", + ), + ) + + most_relevant_memory = windows_memories[0] + assert "Windows" in most_relevant_memory.content + assert "macOS" not in most_relevant_memory.content diff --git a/tests/memory_module/test_memory_storage.py b/tests/memory_module/test_memory_storage.py index 0b4fe067..506c8c5b 100644 --- a/tests/memory_module/test_memory_storage.py +++ b/tests/memory_module/test_memory_storage.py @@ -6,9 +6,10 @@ from memory_module.interfaces.types import ( AssistantMessageInput, BaseMemoryInput, - EmbedText, MemoryType, ShortTermMemoryRetrievalConfig, + TextEmbedding, + Topic, UserMessageInput, ) from memory_module.storage.in_memory_storage import InMemoryStorage @@ -90,10 +91,10 @@ async def test_retrieve_memories(memory_storage, sample_memory_input, sample_emb await memory_storage.store_memory(sample_memory_input, embedding_vectors=sample_embedding) # Create query embedding - query = EmbedText(text="test query", embedding_vector=sample_embedding[0]) + query = TextEmbedding(text="test query", embedding_vector=sample_embedding[0]) # Retrieve memories - memories = await memory_storage.retrieve_memories(query, "test_user", limit=1) + memories = await memory_storage.retrieve_memories(user_id="test_user", text_embedding=query, limit=1) assert len(memories) > 0 assert memories[0].content == sample_memory_input.content @@ -109,9 +110,9 @@ async def test_retrieve_memories_multiple_embeddings(memory_storage, sample_memo await memory_storage.store_memory(sample_memory_input, embedding_vectors=embeddings) # Query should match the second embedding better - query = EmbedText(text="test query", embedding_vector=[1.0] * 1536) + query = TextEmbedding(text="test query", embedding_vector=[1.0] * 1536) - memories = await memory_storage.retrieve_memories(query, "test_user", limit=1) + memories = await memory_storage.retrieve_memories(user_id="test_user", text_embedding=query, limit=1) assert len(memories) == 1 @@ -327,3 +328,167 @@ async def test_get_messages(memory_storage): assert result[memory_id_1][0].id == "msg1" assert result[memory_id_1][1].id == "msg2" assert result[memory_id_2][0].id == "msg2" + + +@pytest.mark.asyncio +async def test_retrieve_memories_by_topic(memory_storage, sample_embedding): + # Store memories with different topics + memory1 = BaseMemoryInput( + content="Memory about AI", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI"], + ) + memory2 = BaseMemoryInput( + content="Memory about nature", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["nature"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=sample_embedding) + + # Retrieve memories by single topic + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="AI", description="")], limit=10 + ) + assert len(memories) == 1 + assert memories[0].content == "Memory about AI" + assert "AI" in memories[0].topics + + # Test with non-existent topic + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="non_existent_topic", description="")], limit=10 + ) + assert len(memories) == 0 + + +@pytest.mark.asyncio +async def test_retrieve_memories_by_topic_and_embedding(memory_storage, sample_embedding): + # Store memories with different topics + memory1 = BaseMemoryInput( + content="Technical discussion about artificial intelligence", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI"], + ) + memory2 = BaseMemoryInput( + content="Another AI related memory but less relevant", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=[[0.2] * 1536]) # Less similar embedding + + # Create query embedding + query = TextEmbedding(text="AI technology", embedding_vector=sample_embedding[0]) + + # Retrieve memories using both topic and semantic similarity + memories = await memory_storage.retrieve_memories( + user_id="test_user", text_embedding=query, topics=[Topic(name="AI", description="")], limit=2 + ) + + assert len(memories) == 1 + assert memories[0].content == "Technical discussion about artificial intelligence" + + +@pytest.mark.asyncio +async def test_retrieve_memories_with_multiple_topics(memory_storage, sample_embedding): + # Store memories with multiple topics + memory1 = BaseMemoryInput( + content="Memory about AI and robotics", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "robotics"], + ) + memory2 = BaseMemoryInput( + content="Memory about AI and machine learning", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "machine learning"], + ) + memory3 = BaseMemoryInput( + content="Memory about nature", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["nature"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory3, embedding_vectors=sample_embedding) + + # Retrieve memories by AI topic (should get both AI-related memories) + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="AI", description="")], limit=10 + ) + assert len(memories) == 2 + assert all("AI" in memory.topics for memory in memories) + + # Retrieve memories by robotics topic (should get only the robotics memory) + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="robotics", description="")], limit=10 + ) + assert len(memories) == 1 + assert "robotics" in memories[0].topics + assert memories[0].content == "Memory about AI and robotics" + + +@pytest.mark.asyncio +async def test_retrieve_memories_with_multiple_topics_parameter(memory_storage, sample_embedding): + # Store memories with multiple topics + memory1 = BaseMemoryInput( + content="Memory about AI and robotics", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "robotics"], + ) + memory2 = BaseMemoryInput( + content="Memory about AI and machine learning", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["AI", "machine learning"], + ) + memory3 = BaseMemoryInput( + content="Memory about nature", + created_at=datetime.now(), + user_id="test_user", + memory_type=MemoryType.SEMANTIC, + message_attributions=[], + topics=["nature"], + ) + + await memory_storage.store_memory(memory1, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory2, embedding_vectors=sample_embedding) + await memory_storage.store_memory(memory3, embedding_vectors=sample_embedding) + + # Retrieve memories by multiple topics + memories = await memory_storage.retrieve_memories( + user_id="test_user", topics=[Topic(name="AI", description=""), Topic(name="robotics", description="")], limit=10 + ) + + # Should get both AI-related memories + assert len(memories) == 2 + assert any("robotics" in memory.topics for memory in memories) + assert all("AI" in memory.topics for memory in memories)