
Commit 83d0e9a

Topics
1 parent 97e44fe commit 83d0e9a

17 files changed: +705 -251 lines

.pre-commit-config.yaml (+2)

@@ -13,3 +13,5 @@ repos:
     rev: v1.13.0
     hooks:
       - id: mypy
+        pass_filenames: false
+        args: [--ignore-missing-imports, --show-traceback]

packages/evals/benchmark_memory_module.py (+4 -4)

@@ -16,9 +16,9 @@

 from memory_module.config import LLMConfig, MemoryModuleConfig
 from memory_module.core.memory_module import MemoryModule
-from memory_module.interfaces.types import AssistantMessage, UserMessage
+from memory_module.interfaces.types import AssistantMessage, RetrievalConfig, UserMessage

-from evals.helpers import Dataset, DatasetItem, load_dataset, setup_mlflow
+from evals.helpers import Dataset, DatasetItem, SessionMessage, load_dataset, setup_mlflow
 from evals.metrics import string_check_metric

 setup_mlflow(experiment_name="memory_module")
@@ -48,7 +48,7 @@ def __exit__(self, exc_type, exc_value, traceback):
         os.remove(self._db_path)


-async def add_messages(memory_module: MemoryModule, messages: List[dict]):
+async def add_messages(memory_module: MemoryModule, messages: List[SessionMessage]):
     def create_message(**kwargs):
         params = {
             "id": str(uuid.uuid4()),
@@ -94,7 +94,7 @@ async def benchmark_memory_module(input: DatasetItem):
     # buffer size has to be the same as the session length to trigger sm processing
     with MemoryModuleManager(buffer_size=len(session)) as memory_module:
         await add_messages(memory_module, messages=session)
-        memories = await memory_module.retrieve_memories(query, user_id=None, limit=None)
+        memories = await memory_module.retrieve_memories(None, RetrievalConfig(query=query, limit=None))

     return {
         "input": {

packages/memory_module/__init__.py (+2)

@@ -9,6 +9,7 @@
     Memory,
     Message,
     MessageInput,
+    RetrievalConfig,
     ShortTermMemoryRetrievalConfig,
     UserMessage,
     UserMessageInput,
@@ -29,6 +30,7 @@
     "MessageInput",
     "AssistantMessage",
     "AssistantMessageInput",
+    "RetrievalConfig",
     "ShortTermMemoryRetrievalConfig",
     "MemoryMiddleware",
 ]

packages/memory_module/config.py (+17)

@@ -3,6 +3,8 @@

 from pydantic import BaseModel, ConfigDict, Field

+from memory_module.interfaces.types import Topic
+

 class LLMConfig(BaseModel):
     """Configuration for LLM service."""
@@ -16,6 +18,18 @@ class LLMConfig(BaseModel):
     embedding_model: Optional[str] = None


+DEFAULT_TOPICS = [
+    Topic(
+        name="General Interests and Preferences",
+        description="When a user mentions specific events or actions, focus on the underlying interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference, not the date or location).",  # noqa: E501
+    ),
+    Topic(
+        name="General Facts about the user",
+        description="Facts that describe relevant information about the user, such as details about where they live or things they own.",  # noqa: E501
+    ),
+]
+
+
 class MemoryModuleConfig(BaseModel):
     """Configuration for memory module components.
@@ -35,4 +49,7 @@ class MemoryModuleConfig(BaseModel):
         description="Seconds to wait before processing a conversation",
     )
     llm: LLMConfig = Field(description="LLM service configuration")
+    topics: list[Topic] = Field(
+        default=DEFAULT_TOPICS, description="List of topics that the memory module should listen to", min_length=1
+    )
     enable_logging: bool = Field(default=False, description="Enable verbose logging for memory module")
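
With this commit the extraction topics become part of the configuration rather than being hard-coded in the prompt. A minimal usage sketch of overriding them (the LLMConfig value and topic texts below are illustrative, not taken from this commit; other optional config fields are left at their defaults):

from memory_module.config import LLMConfig, MemoryModuleConfig
from memory_module.interfaces.types import Topic

config = MemoryModuleConfig(
    llm=LLMConfig(embedding_model="text-embedding-3-small"),  # assumed embedding model name
    topics=[
        Topic(name="Dietary Preferences", description="Foods the user likes, dislikes, or avoids."),
        Topic(name="Device Ownership", description="Hardware and software the user owns or uses."),
    ],
    enable_logging=True,
)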

packages/memory_module/core/memory_core.py (+46 -13)

@@ -10,12 +10,14 @@
 from memory_module.interfaces.base_memory_storage import BaseMemoryStorage
 from memory_module.interfaces.types import (
     BaseMemoryInput,
-    EmbedText,
     Memory,
     MemoryType,
     Message,
     MessageInput,
+    RetrievalConfig,
     ShortTermMemoryRetrievalConfig,
+    TextEmbedding,
+    Topic,
 )
 from memory_module.services.llm_service import LLMService
 from memory_module.storage.in_memory_storage import InMemoryStorage
@@ -51,6 +53,11 @@ class SemanticFact(BaseModel):
         default_factory=set,
         description="The indices of the messages that the fact was extracted from.",
     )
+    # TODO: Add a validator to ensure that topics are valid
+    topics: Optional[List[str]] = Field(
+        default=None,
+        description="The name of the topic that the fact is most relevant to.",  # noqa: E501
+    )


 class SemanticMemoryExtraction(BaseModel):
@@ -106,6 +113,7 @@ def __init__(
         self.storage: BaseMemoryStorage = storage or (
             SQLiteMemoryStorage(db_path=config.db_path) if config.db_path is not None else InMemoryStorage()
         )
+        self.topics = config.topics

     async def process_semantic_messages(
         self,
@@ -145,6 +153,7 @@ async def process_semantic_messages(
                 user_id=author_id,
                 message_attributions=list(message_ids),
                 memory_type=MemoryType.SEMANTIC,
+                topics=fact.topics,
             )
             embed_vectors = await self._get_semantic_fact_embeddings(fact.text, metadata)
             await self.storage.store_memory(memory, embedding_vectors=embed_vectors)
@@ -154,16 +163,37 @@ async def process_episodic_messages(self, messages: List[Message]) -> None:
         # TODO: Implement episodic memory processing
         await self._extract_episodic_memory_from_messages(messages)

-    async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
+    async def retrieve_memories(
+        self,
+        user_id: Optional[str],
+        config: RetrievalConfig,
+    ) -> List[Memory]:
+        return await self._retrieve_memories(
+            user_id, config.query, config.topics if config.topics else None, config.limit
+        )
+
+    async def _retrieve_memories(
+        self,
+        user_id: Optional[str],
+        query: Optional[str],
+        topics: Optional[List[Topic]],
+        limit: Optional[int],
+    ) -> List[Memory]:
         """Retrieve memories based on a query.

         Steps:
         1. Convert query to embedding
         2. Find relevant memories
         3. Possibly rerank or filter results
         """
-        embedText = EmbedText(text=query, embedding_vector=await self._get_query_embedding(query))
-        return await self.storage.retrieve_memories(embedText, user_id, limit)
+        if query:
+            text_embedding = TextEmbedding(text=query, embedding_vector=await self._get_query_embedding(query))
+        else:
+            text_embedding = None
+
+        return await self.storage.retrieve_memories(
+            user_id=user_id, text_embedding=text_embedding, topics=topics, limit=limit
+        )

     async def update_memory(self, memory_id: str, updated_memory: str) -> None:
         metadata = await self._extract_metadata_from_fact(updated_memory)
@@ -195,10 +225,15 @@ async def remove_messages(self, message_ids: List[str]) -> None:
         logger.info("messages {} are removed".format(",".join(message_ids)))

     async def _get_add_memory_processing_decision(
-        self, new_memory_fact: str, user_id: Optional[str]
+        self, new_memory_fact: SemanticFact, user_id: Optional[str]
     ) -> ProcessSemanticMemoryDecision:
-        similar_memories = await self.retrieve_memories(new_memory_fact, user_id, None)
-        decision = await self._extract_memory_processing_decision(new_memory_fact, similar_memories, user_id)
+        topics = (
+            [topic for topic in self.topics if topic.name in new_memory_fact.topics] if new_memory_fact.topics else None
+        )
+        similar_memories = await self.retrieve_memories(
+            user_id, RetrievalConfig(query=new_memory_fact.text, topics=topics, limit=None)
+        )
+        decision = await self._extract_memory_processing_decision(new_memory_fact.text, similar_memories, user_id)
         return decision

     async def _extract_memory_processing_decision(
@@ -306,6 +341,9 @@ async def _extract_semantic_fact_from_messages(
             else:
                 # we explicitly ignore internal messages
                 continue
+        topics = "\n".join(
+            [f"<MEMORY_TOPIC NAME={topic.name}>{topic.description}</MEMORY_TOPIC>" for topic in self.topics]
+        )

         existing_memories_str = ""
         if existing_memories:
@@ -318,11 +356,7 @@ async def _extract_semantic_fact_from_messages(
 that will remain relevant over time, even if the user is mentioning short-term plans or events.

 Prioritize:
-- General Interests and Preferences: When a user mentions specific events or actions, focus on the underlying
-interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference,
-not the date or location).
-- Facts or Details about user: Extract facts that describe relevant information about the user, such as details about things they own.
-- Facts about the user that the assistant might find useful.
+{topics}

 Avoid:
 - Extraction memories that already exist in the system. If a fact is already stored, ignore it.
@@ -335,7 +369,6 @@ async def _extract_semantic_fact_from_messages(
 {messages_str}
 </TRANSCRIPT>
 """  # noqa: E501
-
         llm_messages = [
             {"role": "system", "content": system_message},
             {
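
The hard-coded "Prioritize" bullets in the extraction prompt are replaced by a fragment rendered from the configured topics. A quick sketch of what that fragment looks like for the new DEFAULT_TOPICS (output abbreviated):

from memory_module.config import DEFAULT_TOPICS

topics = "\n".join(
    [f"<MEMORY_TOPIC NAME={topic.name}>{topic.description}</MEMORY_TOPIC>" for topic in DEFAULT_TOPICS]
)
print(topics)
# <MEMORY_TOPIC NAME=General Interests and Preferences>When a user mentions specific events or actions, ...</MEMORY_TOPIC>
# <MEMORY_TOPIC NAME=General Facts about the user>Facts that describe relevant information about the user, ...</MEMORY_TOPIC>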

packages/memory_module/core/memory_module.py (+14 -4)

@@ -7,7 +7,13 @@
 from memory_module.interfaces.base_memory_core import BaseMemoryCore
 from memory_module.interfaces.base_memory_module import BaseMemoryModule
 from memory_module.interfaces.base_message_queue import BaseMessageQueue
-from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
+from memory_module.interfaces.types import (
+    Memory,
+    Message,
+    MessageInput,
+    RetrievalConfig,
+    ShortTermMemoryRetrievalConfig,
+)
 from memory_module.services.llm_service import LLMService
 from memory_module.utils.logging import configure_logging

@@ -50,10 +56,14 @@ async def add_message(self, message: MessageInput) -> Message:
         await self.message_queue.enqueue(message_res)
         return message_res

-    async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
+    async def retrieve_memories(
+        self,
+        user_id: Optional[str],
+        config: RetrievalConfig,
+    ) -> List[Memory]:
         """Retrieve relevant memories based on a query."""
-        logger.debug(f"retrieve memories from (query: {query}, user_id: {user_id}, limit: {limit})")
-        memories = await self.memory_core.retrieve_memories(query, user_id, limit)
+        logger.debug(f"retrieve memories from (query: {config.query}, user_id: {user_id}, limit: {config.limit})")
+        memories = await self.memory_core.retrieve_memories(user_id=user_id, config=config)
         logger.debug(f"retrieved memories: {memories}")
         return memories
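
Callers of MemoryModule.retrieve_memories migrate from positional (query, user_id, limit) arguments to (user_id, RetrievalConfig). A minimal sketch of the new call shape (the wrapper function and query text are illustrative):

from memory_module import RetrievalConfig


async def fetch_food_preferences(memory_module, user_id):
    # Before this commit: memory_module.retrieve_memories(query, user_id=user_id, limit=5)
    return await memory_module.retrieve_memories(
        user_id,
        RetrievalConfig(query="What foods does the user like?", limit=5),
    )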

packages/memory_module/interfaces/base_memory_core.py (+12 -2)

@@ -1,7 +1,13 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional

-from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
+from memory_module.interfaces.types import (
+    Memory,
+    Message,
+    MessageInput,
+    RetrievalConfig,
+    ShortTermMemoryRetrievalConfig,
+)


 class BaseMemoryCore(ABC):
@@ -22,7 +28,11 @@ async def process_episodic_messages(self, messages: List[Message]) -> None:
         pass

     @abstractmethod
-    async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
+    async def retrieve_memories(
+        self,
+        user_id: Optional[str],
+        config: RetrievalConfig,
+    ) -> List[Memory]:
         """Retrieve memories based on a query."""
         pass

packages/memory_module/interfaces/base_memory_module.py (+12 -2)

@@ -1,7 +1,13 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional

-from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
+from memory_module.interfaces.types import (
+    Memory,
+    Message,
+    MessageInput,
+    RetrievalConfig,
+    ShortTermMemoryRetrievalConfig,
+)


 class BaseMemoryModule(ABC):
@@ -13,7 +19,11 @@ async def add_message(self, message: MessageInput) -> Message:
         pass

     @abstractmethod
-    async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
+    async def retrieve_memories(
+        self,
+        user_id: Optional[str],
+        config: RetrievalConfig,
+    ) -> List[Memory]:
         """Retrieve relevant memories based on a query."""
         pass

packages/memory_module/interfaces/base_memory_storage.py (+8 -2)

@@ -3,11 +3,12 @@

 from memory_module.interfaces.types import (
     BaseMemoryInput,
-    EmbedText,
     Memory,
     Message,
     MessageInput,
     ShortTermMemoryRetrievalConfig,
+    TextEmbedding,
+    Topic,
 )


@@ -47,7 +48,12 @@ async def store_short_term_memory(self, message: MessageInput) -> Message:

     @abstractmethod
     async def retrieve_memories(
-        self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None
+        self,
+        *,
+        user_id: Optional[str],
+        text_embedding: Optional[TextEmbedding] = None,
+        topics: Optional[List[Topic]] = None,
+        limit: Optional[int] = None,
     ) -> List[Memory]:
         """Retrieve memories based on a query.
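
The storage interface now takes keyword-only arguments, an optional topic filter, and the renamed TextEmbedding. A rough sketch of how a backend might satisfy the new contract (list-backed, with similarity ranking elided, so purely illustrative; the repo's SQLite implementation will differ):

from typing import List, Optional

from memory_module.interfaces.base_memory_storage import BaseMemoryStorage
from memory_module.interfaces.types import Memory, TextEmbedding, Topic


class ListBackedStorage(BaseMemoryStorage):
    # Remaining abstract methods are omitted; a real subclass must implement
    # them before it can be instantiated.

    def __init__(self) -> None:
        self._memories: List[Memory] = []

    async def retrieve_memories(
        self,
        *,
        user_id: Optional[str],
        text_embedding: Optional[TextEmbedding] = None,
        topics: Optional[List[Topic]] = None,
        limit: Optional[int] = None,
    ) -> List[Memory]:
        results = [m for m in self._memories if user_id is None or m.user_id == user_id]
        if topics:
            names = {t.name for t in topics}
            results = [m for m in results if m.topics and names.intersection(m.topics)]
        # A real backend would rank results by similarity to text_embedding here.
        return results[:limit] if limit else results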

packages/memory_module/interfaces/types.py (+20 -2)

@@ -104,6 +104,12 @@ class BaseMemoryInput(BaseModel):
     memory_type: MemoryType
     user_id: Optional[str] = None
     message_attributions: Optional[List[str]] = Field(default=[])
+    topics: Optional[List[str]] = None
+
+
+class Topic(BaseModel):
+    name: str = Field(description="A unique name of the topic that the memory module should listen to")
+    description: str = Field(description="Description of the topic")


 class Memory(BaseMemoryInput):
@@ -112,12 +118,24 @@ class Memory(BaseMemoryInput):
     id: str


-class EmbedText(BaseModel):
+class TextEmbedding(BaseModel):
     text: str
     embedding_vector: List[float]


-class ShortTermMemoryRetrievalConfig(BaseModel):
+class RetrievalConfig(BaseModel):
+    query: Optional[str] = None
+    topic: Optional[Topic] = None
+    limit: Optional[int] = None
+
+    @model_validator(mode="after")
+    def check_parameters(self) -> "RetrievalConfig":
+        if self.query is None and self.topic is None:
+            raise ValueError("Either query or topic must be provided")
+        return self
+
+
+class ShortTermMemoryRetrievalConfig(RetrievalConfig):
     n_messages: Optional[int] = None  # Number of messages to retrieve
     last_minutes: Optional[float] = None  # Time frame in minutes
     before: Optional[datetime] = None  # Retrieve messages up until a specific timestamp
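
RetrievalConfig enforces that at least one of query or topic is set, so an empty retrieval request fails fast. A small sketch of that behavior (assuming pydantic v2 semantics, where a ValueError raised inside a model_validator surfaces as a ValidationError):

from pydantic import ValidationError

from memory_module import RetrievalConfig

RetrievalConfig(query="favorite foods", limit=3)  # valid: query is provided

try:
    RetrievalConfig(limit=3)  # neither query nor topic is set
except ValidationError as exc:
    print(exc)  # "Either query or topic must be provided"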
