Skip to content

Commit 0d57654

Browse files
committed
Add tests
1 parent b58a3eb commit 0d57654

File tree

3 files changed

+224
-6
lines changed

3 files changed

+224
-6
lines changed

packages/memory_module/core/memory_core.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ async def process_semantic_messages(
134134
if extraction.action == "add" and extraction.facts:
135135
for fact in extraction.facts:
136136
decision = await self._get_add_memory_processing_decision(fact.text, author_id)
137-
if decision == "ignore":
137+
if decision.decision == "ignore":
138138
logger.info(f"Decision to ignore fact {fact.text}")
139139
continue
140140
metadata = await self._extract_metadata_from_fact(fact.text)
@@ -194,10 +194,12 @@ async def remove_messages(self, message_ids: List[str]) -> None:
194194
await self.storage.remove_messages(message_ids)
195195
logger.info("messages {} are removed".format(",".join(message_ids)))
196196

197-
async def _get_add_memory_processing_decision(self, new_memory_fact: str, user_id: Optional[str]) -> str:
197+
async def _get_add_memory_processing_decision(
198+
self, new_memory_fact: str, user_id: Optional[str]
199+
) -> ProcessSemanticMemoryDecision:
198200
similar_memories = await self.retrieve_memories(new_memory_fact, user_id, None)
199201
decision = await self._extract_memory_processing_decision(new_memory_fact, similar_memories, user_id)
200-
return decision.decision
202+
return decision
201203

202204
async def _extract_memory_processing_decision(
203205
self, new_memory: str, old_memories: List[Memory], user_id: Optional[str]

scripts/evaluate_memory_decisions.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import asyncio
2+
import sys
3+
from datetime import datetime, timedelta
4+
from pathlib import Path
5+
from typing import cast
6+
from uuid import uuid4
7+
8+
from memory_module import MemoryModule, MemoryModuleConfig, UserMessageInput
9+
from memory_module.core.message_buffer import MessageBuffer
10+
from memory_module.core.message_queue import MessageQueue
11+
from memory_module.core.scheduler import Scheduler
12+
from tqdm import tqdm
13+
14+
from tests.memory_module.utils import build_llm_config
15+
16+
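# Decision test cases: each pairs previously stored messages with an incoming
# message and the decision ("add" or "ignore") expected for that message.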
17+
TEST_CASES = [
18+
{
19+
"title": "General vs. Specific Detail",
20+
"old_messages": ["I love outdoor activities.", "I often visit national parks."],
21+
"incoming_message": "I enjoy hiking in Yellowstone National Park.",
22+
"expected_decision": "ignore",
23+
"reason": "The old messages already cover the new message’s context.",
24+
},
25+
{
26+
"title": "Specific Detail vs. General",
27+
"old_messages": ["I really enjoy hiking in Yellowstone National Park.", "I like exploring scenic trails."],
28+
"incoming_message": "I enjoy hiking in national parks.",
29+
"expected_decision": "ignore",
30+
"reason": "The new message is broader and redundant to the old messages.",
31+
},
32+
{
33+
"title": "Repeated Behavior Over Time",
34+
"old_messages": ["I had coffee at 8 AM yesterday.", "I had coffee at 8 AM this morning."],
35+
"incoming_message": "I had coffee at 8 AM again today.",
36+
"expected_decision": "add",
37+
"reason": "This reinforces a recurring pattern of behavior over time.",
38+
},
39+
{
40+
"title": "Updated Temporal Context",
41+
"old_messages": ["I’m planning a trip to Japan.", "I’ve been looking at flights to Japan."],
42+
"incoming_message": "I just canceled my trip to Japan.",
43+
"expected_decision": "add",
44+
"reason": "The new message reflects a significant update to the old messages.",
45+
},
46+
{
47+
"title": "Irrelevant or Unnecessary Update",
48+
"old_messages": ["I prefer tea over coffee.", "I usually drink tea every day."],
49+
"incoming_message": "I like tea.",
50+
"expected_decision": "ignore",
51+
"reason": "The new message does not add any unique or relevant information.",
52+
},
53+
{
54+
"title": "Redundant Memory with Different Wording",
55+
"old_messages": ["I have an iPhone 12.", "I bought an iPhone 12 back in 2022."],
56+
"incoming_message": "I own an iPhone 12.",
57+
"expected_decision": "ignore",
58+
"reason": "The new message is a rephrased duplicate of the old messages.",
59+
},
60+
{
61+
"title": "Additional Specific Information",
62+
"old_messages": ["I like playing video games.", "I often play games on my console."],
63+
"incoming_message": "I love playing RPG video games like Final Fantasy.",
64+
"expected_decision": "add",
65+
"reason": "The new message adds specific details about the type of games.",
66+
},
67+
{
68+
"title": "Contradictory Information",
69+
"old_messages": ["I like cats.", "I have a cat named Whiskers."],
70+
"incoming_message": "Actually, I don’t like cats.",
71+
"expected_decision": "add",
72+
"reason": "The new message reflects a contradiction or change in preference.",
73+
},
74+
{
75+
"title": "New Memory Completely Unrelated",
76+
"old_messages": ["I love reading mystery novels.", "I’m a big fan of Agatha Christie’s books."],
77+
"incoming_message": "I really enjoy playing soccer.",
78+
"expected_decision": "add",
79+
"reason": "The new message introduces entirely new information.",
80+
},
81+
{
82+
"title": "Multiple Old Messages with Partial Overlap",
83+
"old_messages": ["I have a car.", "My car is a Toyota Camry."],
84+
"incoming_message": "I own a blue Toyota Camry.",
85+
"expected_decision": "add",
86+
"reason": "The new message adds a specific detail (color) not covered by the old messages.",
87+
},
88+
]
89+
90+
91+
async def evaluate_decision(memory_module, test_case):
    """Run one decision test case and report whether the outcome matched.

    The case's old messages are stored (dated one day in the past), the
    scheduler is flushed, and the incoming message is then run through fact
    extraction followed by the add-memory decision step.

    Args:
        memory_module: Module under evaluation; must expose ``add_message``,
            ``message_queue.message_buffer.scheduler`` and ``memory_core``.
        test_case: Dict with ``old_messages``, ``incoming_message`` and
            ``expected_decision`` keys.

    Returns:
        A dict with ``success``, ``expected``, ``got``, ``reason`` and
        ``test_case`` keys. When fact extraction does not yield facts to add,
        ``success`` is False, ``got`` is ``"failed_extraction"`` and an extra
        ``error`` key is present.
    """
    author = "user-123"
    conversation_ref = str(uuid4())

    def build_message(content, created_at):
        # Every message in a case shares the same author and conversation.
        return UserMessageInput(
            id=str(uuid4()),
            content=content,
            author_id=author,
            conversation_ref=conversation_ref,
            created_at=created_at,
        )

    # Seed the module with the pre-existing messages, dated one day ago.
    for old_content in test_case["old_messages"]:
        await memory_module.add_message(build_message(old_content, datetime.now() - timedelta(days=1)))

    # NOTE(review): presumably forces buffered messages to be processed before
    # the decision step runs -- confirm against Scheduler.flush.
    await memory_module.message_queue.message_buffer.scheduler.flush()

    incoming = [build_message(test_case["incoming_message"], datetime.now())]

    core = memory_module.memory_core
    expected = test_case["expected_decision"]

    extraction = await core._extract_semantic_fact_from_messages(incoming)
    if extraction.action != "add" or not extraction.facts:
        return {
            "success": False,
            "error": "Failed to extract semantic facts",
            "test_case": test_case,
            "expected": expected,
            "got": "failed_extraction",
            "reason": "Failed to extract semantic facts",
        }

    # Returns inside the loop: only the first extracted fact is evaluated.
    for fact in extraction.facts:
        outcome = await core._get_add_memory_processing_decision(fact.text, author)
        return {
            "success": outcome.decision == expected,
            "expected": expected,
            "got": outcome.decision,
            "reason": outcome.reason_for_decision,
            "test_case": test_case,
        }
140+
141+
142+
async def main():
    """Evaluate every entry in TEST_CASES and print a summary report.

    Builds a MemoryModule backed by a fresh database file under
    ``data/evaluation`` next to this script, runs ``evaluate_decision`` for
    each test case, prints success/failure statistics, and lists the details
    of every failed case.

    Exits the process with status 1 when the LLM config has no API key.
    The scheduler is cleaned up even if an evaluation raises.
    """
    llm_config = build_llm_config()
    if not llm_config.api_key:
        print("Error: OpenAI API key not provided")
        sys.exit(1)

    db_path = Path(__file__).parent / "data" / "evaluation" / "memory_module.db"
    # Create db directory if it doesn't exist
    db_path.parent.mkdir(parents=True, exist_ok=True)
    config = MemoryModuleConfig(
        db_path=db_path,
        buffer_size=5,
        timeout_seconds=60,
        llm=llm_config,
    )

    # Remove any database file left behind by a previous run.
    if db_path.exists():
        db_path.unlink()

    memory_module = MemoryModule(config=config)

    try:
        results = []

        # Run evaluations with progress bar
        print("\nEvaluating memory processing decisions...")
        for test_case in tqdm(TEST_CASES, desc="Processing test cases"):
            results.append(await evaluate_decision(memory_module, test_case))

        # Calculate statistics
        failed_results = [result for result in results if not result["success"]]
        total = len(TEST_CASES)
        failures = len(failed_results)
        successes = total - failures
        # Guard against an empty TEST_CASES list (would divide by zero).
        success_rate = (successes / total) * 100 if total else 0.0
        failure_rate = 100 - success_rate if total else 0.0

        # Print summary
        print("\n=== Evaluation Summary ===")
        print(f"Total test cases: {total}")
        print(f"Successes: {successes} ({success_rate:.1f}%)")
        print(f"Failures: {failures} ({failure_rate:.1f}%)")

        # Print detailed failures if any
        if failed_results:
            print("\n=== Failed Cases ===")
            for result in failed_results:
                test_case = result["test_case"]
                print(f"\nTest Case: {test_case['title']}")
                print(f"Reason: {test_case['reason']}")
                print(f"Actual result: {result['reason']}")
                print(f"Expected: {result['expected']}")
                print(f"Got: {result['got']}")
                print("Old messages:")
                for msg in test_case["old_messages"]:
                    print(f"  - {msg}")
                print(f"New message: {test_case['incoming_message']}")
                print("-" * 50)
    finally:
        # Cleanup runs on both success and failure so the scheduler is never
        # left running after an exception in the evaluation loop.
        message_queue = cast(MessageQueue, memory_module.message_queue)
        message_buffer = cast(MessageBuffer, message_queue.message_buffer)
        scheduler = cast(Scheduler, message_buffer.scheduler)
        await scheduler.cleanup()
211+
212+
213+
if __name__ == "__main__":
214+
asyncio.run(main())

tests/memory_module/test_memory_module.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,10 @@ async def _mock_embedding(**kwargs):
116116
@pytest_asyncio.fixture(autouse=True)
117117
async def cleanup_scheduled_events(memory_module):
118118
"""Fixture to cleanup scheduled events after each test."""
119-
yield
120-
await memory_module.message_queue.message_buffer.scheduler.cleanup()
119+
try:
120+
yield
121+
finally:
122+
await memory_module.message_queue.message_buffer.scheduler.cleanup()
121123

122124

123125
@pytest.mark.asyncio
@@ -342,7 +344,7 @@ async def _validate_decision(memory_module, message: List[UserMessageInput], exp
342344
assert extraction.action == "add" and extraction.facts
343345
for fact in extraction.facts:
344346
decision = await memory_module.memory_core._get_add_memory_processing_decision(fact.text, "user-123")
345-
assert decision == expected_decision
347+
assert decision.decision == expected_decision
346348

347349

348350
@pytest.mark.asyncio

0 commit comments

Comments
 (0)