Skip to content

Add ability to specify topics for memory extraction and retrieval #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 10, 2025
8 changes: 4 additions & 4 deletions packages/evals/benchmark_memory_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from memory_module.config import LLMConfig, MemoryModuleConfig
from memory_module.core.memory_module import MemoryModule
from memory_module.interfaces.types import AssistantMessage, UserMessage
from memory_module.interfaces.types import AssistantMessage, RetrievalConfig, UserMessage

from evals.helpers import Dataset, DatasetItem, load_dataset, setup_mlflow
from evals.helpers import Dataset, DatasetItem, SessionMessage, load_dataset, setup_mlflow
from evals.metrics import string_check_metric

setup_mlflow(experiment_name="memory_module")
Expand Down Expand Up @@ -48,7 +48,7 @@ def __exit__(self, exc_type, exc_value, traceback):
os.remove(self._db_path)


async def add_messages(memory_module: MemoryModule, messages: List[dict]):
async def add_messages(memory_module: MemoryModule, messages: List[SessionMessage]):
def create_message(**kwargs):
params = {
"id": str(uuid.uuid4()),
Expand Down Expand Up @@ -94,7 +94,7 @@ async def benchmark_memory_module(input: DatasetItem):
# buffer size has to be the same as the session length to trigger sm processing
with MemoryModuleManager(buffer_size=len(session)) as memory_module:
await add_messages(memory_module, messages=session)
memories = await memory_module.retrieve_memories(query, user_id=None, limit=None)
memories = await memory_module.retrieve_memories(None, RetrievalConfig(query=query, limit=None))

return {
"input": {
Expand Down
2 changes: 2 additions & 0 deletions packages/memory_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
UserMessage,
UserMessageInput,
Expand All @@ -29,6 +30,7 @@
"MessageInput",
"AssistantMessage",
"AssistantMessageInput",
"RetrievalConfig",
"ShortTermMemoryRetrievalConfig",
"MemoryMiddleware",
]
17 changes: 17 additions & 0 deletions packages/memory_module/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel, ConfigDict, Field

from memory_module.interfaces.types import Topic


class LLMConfig(BaseModel):
"""Configuration for LLM service."""
Expand All @@ -16,6 +18,18 @@ class LLMConfig(BaseModel):
embedding_model: Optional[str] = None


DEFAULT_TOPICS = [
Topic(
name="General Interests and Preferences",
description="When a user mentions specific events or actions, focus on the underlying interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference, not the date or location).", # noqa: E501
),
Topic(
name="General Facts about the user",
description="Facts that describe relevant information about the user, such as details about where they live or things they own.", # noqa: E501
),
]


class MemoryModuleConfig(BaseModel):
"""Configuration for memory module components.

Expand All @@ -35,4 +49,7 @@ class MemoryModuleConfig(BaseModel):
description="Seconds to wait before processing a conversation",
)
llm: LLMConfig = Field(description="LLM service configuration")
topics: list[Topic] = Field(
default=DEFAULT_TOPICS, description="List of topics that the memory module should listen to", min_length=1
)
enable_logging: bool = Field(default=False, description="Enable verbose logging for memory module")
59 changes: 45 additions & 14 deletions packages/memory_module/core/memory_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from memory_module.interfaces.base_memory_storage import BaseMemoryStorage
from memory_module.interfaces.types import (
BaseMemoryInput,
EmbedText,
Memory,
MemoryType,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
TextEmbedding,
Topic,
)
from memory_module.services.llm_service import LLMService
from memory_module.storage.in_memory_storage import InMemoryStorage
Expand Down Expand Up @@ -51,6 +53,11 @@ class SemanticFact(BaseModel):
default_factory=set,
description="The indices of the messages that the fact was extracted from.",
)
# TODO: Add a validator to ensure that topics are valid
topics: Optional[List[str]] = Field(
default=None,
description="The name of the topic that the fact is most relevant to.", # noqa: E501
)


class SemanticMemoryExtraction(BaseModel):
Expand Down Expand Up @@ -106,6 +113,7 @@ def __init__(
self.storage: BaseMemoryStorage = storage or (
SQLiteMemoryStorage(db_path=config.db_path) if config.db_path is not None else InMemoryStorage()
)
self.topics = config.topics

async def process_semantic_messages(
self,
Expand Down Expand Up @@ -133,7 +141,7 @@ async def process_semantic_messages(

if extraction.action == "add" and extraction.facts:
for fact in extraction.facts:
decision = await self._get_add_memory_processing_decision(fact.text, author_id)
decision = await self._get_add_memory_processing_decision(fact, author_id)
if decision.decision == "ignore":
logger.info(f"Decision to ignore fact {fact.text}")
continue
Expand All @@ -145,6 +153,7 @@ async def process_semantic_messages(
user_id=author_id,
message_attributions=list(message_ids),
memory_type=MemoryType.SEMANTIC,
topics=fact.topics,
)
embed_vectors = await self._get_semantic_fact_embeddings(fact.text, metadata)
await self.storage.store_memory(memory, embedding_vectors=embed_vectors)
Expand All @@ -154,16 +163,37 @@ async def process_episodic_messages(self, messages: List[Message]) -> None:
# TODO: Implement episodic memory processing
await self._extract_episodic_memory_from_messages(messages)

async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
self,
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
return await self._retrieve_memories(
user_id, config.query, [config.topic] if config.topic else None, config.limit
)

async def _retrieve_memories(
self,
user_id: Optional[str],
query: Optional[str],
topics: Optional[List[Topic]],
limit: Optional[int],
) -> List[Memory]:
"""Retrieve memories based on a query.

Steps:
1. Convert query to embedding
2. Find relevant memories
3. Possibly rerank or filter results
"""
embedText = EmbedText(text=query, embedding_vector=await self._get_query_embedding(query))
return await self.storage.retrieve_memories(embedText, user_id, limit)
if query:
text_embedding = TextEmbedding(text=query, embedding_vector=await self._get_query_embedding(query))
else:
text_embedding = None

return await self.storage.retrieve_memories(
user_id=user_id, text_embedding=text_embedding, topics=topics, limit=limit
)

async def update_memory(self, memory_id: str, updated_memory: str) -> None:
metadata = await self._extract_metadata_from_fact(updated_memory)
Expand Down Expand Up @@ -195,10 +225,13 @@ async def remove_messages(self, message_ids: List[str]) -> None:
logger.info("messages {} are removed".format(",".join(message_ids)))

async def _get_add_memory_processing_decision(
self, new_memory_fact: str, user_id: Optional[str]
self, new_memory_fact: SemanticFact, user_id: Optional[str]
) -> ProcessSemanticMemoryDecision:
similar_memories = await self.retrieve_memories(new_memory_fact, user_id, None)
decision = await self._extract_memory_processing_decision(new_memory_fact, similar_memories, user_id)
# topics = (
# [topic for topic in self.topics if topic.name in new_memory_fact.topics] if new_memory_fact.topics else None # noqa: E501
# )
similar_memories = await self._retrieve_memories(user_id, new_memory_fact.text, None, None)
decision = await self._extract_memory_processing_decision(new_memory_fact.text, similar_memories, user_id)
return decision

async def _extract_memory_processing_decision(
Expand Down Expand Up @@ -306,6 +339,9 @@ async def _extract_semantic_fact_from_messages(
else:
# we explicitly ignore internal messages
continue
topics = "\n".join(
[f"<MEMORY_TOPIC NAME={topic.name}>{topic.description}</MEMORY_TOPIC>" for topic in self.topics]
)

existing_memories_str = ""
if existing_memories:
Expand All @@ -318,11 +354,7 @@ async def _extract_semantic_fact_from_messages(
that will remain relevant over time, even if the user is mentioning short-term plans or events.

Prioritize:
- General Interests and Preferences: When a user mentions specific events or actions, focus on the underlying
interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference,
not the date or location).
- Facts or Details about user: Extract facts that describe relevant information about the user, such as details about things they own.
- Facts about the user that the assistant might find useful.
{topics}

Avoid:
- Extraction memories that already exist in the system. If a fact is already stored, ignore it.
Expand All @@ -335,7 +367,6 @@ async def _extract_semantic_fact_from_messages(
{messages_str}
</TRANSCRIPT>
""" # noqa: E501

llm_messages = [
{"role": "system", "content": system_message},
{
Expand Down
18 changes: 14 additions & 4 deletions packages/memory_module/core/memory_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from memory_module.interfaces.base_memory_core import BaseMemoryCore
from memory_module.interfaces.base_memory_module import BaseMemoryModule
from memory_module.interfaces.base_message_queue import BaseMessageQueue
from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
from memory_module.interfaces.types import (
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
)
from memory_module.services.llm_service import LLMService
from memory_module.utils.logging import configure_logging

Expand Down Expand Up @@ -50,10 +56,14 @@ async def add_message(self, message: MessageInput) -> Message:
await self.message_queue.enqueue(message_res)
return message_res

async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
self,
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
"""Retrieve relevant memories based on a query."""
logger.debug(f"retrieve memories from (query: {query}, user_id: {user_id}, limit: {limit})")
memories = await self.memory_core.retrieve_memories(query, user_id, limit)
logger.debug(f"retrieve memories from (query: {config.query}, user_id: {user_id}, limit: {config.limit})")
memories = await self.memory_core.retrieve_memories(user_id=user_id, config=config)
logger.debug(f"retrieved memories: {memories}")
return memories

Expand Down
14 changes: 12 additions & 2 deletions packages/memory_module/interfaces/base_memory_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
from memory_module.interfaces.types import (
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
)


class BaseMemoryCore(ABC):
Expand All @@ -22,7 +28,11 @@ async def process_episodic_messages(self, messages: List[Message]) -> None:
pass

@abstractmethod
async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
self,
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
"""Retrieve memories based on a query."""
pass

Expand Down
14 changes: 12 additions & 2 deletions packages/memory_module/interfaces/base_memory_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
from memory_module.interfaces.types import (
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
)


class BaseMemoryModule(ABC):
Expand All @@ -13,7 +19,11 @@ async def add_message(self, message: MessageInput) -> Message:
pass

@abstractmethod
async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
self,
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
"""Retrieve relevant memories based on a query."""
pass

Expand Down
10 changes: 8 additions & 2 deletions packages/memory_module/interfaces/base_memory_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from memory_module.interfaces.types import (
BaseMemoryInput,
EmbedText,
Memory,
Message,
MessageInput,
ShortTermMemoryRetrievalConfig,
TextEmbedding,
Topic,
)


Expand Down Expand Up @@ -47,7 +48,12 @@ async def store_short_term_memory(self, message: MessageInput) -> Message:

@abstractmethod
async def retrieve_memories(
self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None
self,
*,
user_id: Optional[str],
text_embedding: Optional[TextEmbedding] = None,
topics: Optional[List[Topic]] = None,
limit: Optional[int] = None,
) -> List[Memory]:
"""Retrieve memories based on a query.

Expand Down
38 changes: 36 additions & 2 deletions packages/memory_module/interfaces/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ class BaseMemoryInput(BaseModel):
memory_type: MemoryType
user_id: Optional[str] = None
message_attributions: Optional[List[str]] = Field(default=[])
topics: Optional[List[str]] = None


class Topic(BaseModel):
name: str = Field(description="A unique name of the topic that the memory module should listen to")
description: str = Field(description="Description of the topic")


class Memory(BaseMemoryInput):
Expand All @@ -112,12 +118,40 @@ class Memory(BaseMemoryInput):
id: str


class EmbedText(BaseModel):
class TextEmbedding(BaseModel):
text: str
embedding_vector: List[float]


class ShortTermMemoryRetrievalConfig(BaseModel):
class RetrievalConfig(BaseModel):
"""Configuration for memory retrieval operations.

This class defines the parameters used to retrieve memories from storage. Memories can be
retrieved either by a semantic search query or by filtering for a specific topic or both.

In case of both, the memories are retrieved by the intersection of the two sets.
"""

query: Optional[str] = Field(
default=None, description="A natural language query to search for semantically similar memories"
)
topic: Optional[Topic] = Field(
default=None,
description="Topic to filter memories by. Only memories tagged with this topic will be retrieved",
)
limit: Optional[int] = Field(
default=None,
description="Maximum number of memories to retrieve. If not specified, all matching memories are returned",
)

@model_validator(mode="after")
def check_parameters(self) -> "RetrievalConfig":
if self.query is None and self.topic is None:
raise ValueError("Either query or topic must be provided")
return self


class ShortTermMemoryRetrievalConfig(RetrievalConfig):
n_messages: Optional[int] = None # Number of messages to retrieve
last_minutes: Optional[float] = None # Time frame in minutes
before: Optional[datetime] = None # Retrieve messages up until a specific timestamp
Expand Down
Loading
Loading