|
| 1 | +import asyncio |
| 2 | +import sys |
| 3 | +from datetime import datetime, timedelta |
| 4 | +from pathlib import Path |
| 5 | +from typing import cast |
| 6 | +from uuid import uuid4 |
| 7 | + |
| 8 | +from memory_module import MemoryModule, MemoryModuleConfig, UserMessageInput |
| 9 | +from memory_module.core.message_buffer import MessageBuffer |
| 10 | +from memory_module.core.message_queue import MessageQueue |
| 11 | +from memory_module.core.scheduler import Scheduler |
| 12 | +from tqdm import tqdm |
| 13 | + |
| 14 | +from tests.memory_module.utils import build_llm_config |
| 15 | + |
| 16 | +# Test cases from before |
| 17 | +TEST_CASES = [ |
| 18 | + { |
| 19 | + "title": "General vs. Specific Detail", |
| 20 | + "old_messages": ["I love outdoor activities.", "I often visit national parks."], |
| 21 | + "incoming_message": "I enjoy hiking in Yellowstone National Park.", |
| 22 | + "expected_decision": "ignore", |
| 23 | + "reason": "The old messages already cover the new message’s context.", |
| 24 | + }, |
| 25 | + { |
| 26 | + "title": "Specific Detail vs. General", |
| 27 | + "old_messages": ["I really enjoy hiking in Yellowstone National Park.", "I like exploring scenic trails."], |
| 28 | + "incoming_message": "I enjoy hiking in national parks.", |
| 29 | + "expected_decision": "ignore", |
| 30 | + "reason": "The new message is broader and redundant to the old messages.", |
| 31 | + }, |
| 32 | + { |
| 33 | + "title": "Repeated Behavior Over Time", |
| 34 | + "old_messages": ["I had coffee at 8 AM yesterday.", "I had coffee at 8 AM this morning."], |
| 35 | + "incoming_message": "I had coffee at 8 AM again today.", |
| 36 | + "expected_decision": "add", |
| 37 | + "reason": "This reinforces a recurring pattern of behavior over time.", |
| 38 | + }, |
| 39 | + { |
| 40 | + "title": "Updated Temporal Context", |
| 41 | + "old_messages": ["I’m planning a trip to Japan.", "I’ve been looking at flights to Japan."], |
| 42 | + "incoming_message": "I just canceled my trip to Japan.", |
| 43 | + "expected_decision": "add", |
| 44 | + "reason": "The new message reflects a significant update to the old messages.", |
| 45 | + }, |
| 46 | + { |
| 47 | + "title": "Irrelevant or Unnecessary Update", |
| 48 | + "old_messages": ["I prefer tea over coffee.", "I usually drink tea every day."], |
| 49 | + "incoming_message": "I like tea.", |
| 50 | + "expected_decision": "ignore", |
| 51 | + "reason": "The new message does not add any unique or relevant information.", |
| 52 | + }, |
| 53 | + { |
| 54 | + "title": "Redundant Memory with Different Wording", |
| 55 | + "old_messages": ["I have an iPhone 12.", "I bought an iPhone 12 back in 2022."], |
| 56 | + "incoming_message": "I own an iPhone 12.", |
| 57 | + "expected_decision": "ignore", |
| 58 | + "reason": "The new message is a rephrased duplicate of the old messages.", |
| 59 | + }, |
| 60 | + { |
| 61 | + "title": "Additional Specific Information", |
| 62 | + "old_messages": ["I like playing video games.", "I often play games on my console."], |
| 63 | + "incoming_message": "I love playing RPG video games like Final Fantasy.", |
| 64 | + "expected_decision": "add", |
| 65 | + "reason": "The new message adds specific details about the type of games.", |
| 66 | + }, |
| 67 | + { |
| 68 | + "title": "Contradictory Information", |
| 69 | + "old_messages": ["I like cats.", "I have a cat named Whiskers."], |
| 70 | + "incoming_message": "Actually, I don’t like cats.", |
| 71 | + "expected_decision": "add", |
| 72 | + "reason": "The new message reflects a contradiction or change in preference.", |
| 73 | + }, |
| 74 | + { |
| 75 | + "title": "New Memory Completely Unrelated", |
| 76 | + "old_messages": ["I love reading mystery novels.", "I’m a big fan of Agatha Christie’s books."], |
| 77 | + "incoming_message": "I really enjoy playing soccer.", |
| 78 | + "expected_decision": "add", |
| 79 | + "reason": "The new message introduces entirely new information.", |
| 80 | + }, |
| 81 | + { |
| 82 | + "title": "Multiple Old Messages with Partial Overlap", |
| 83 | + "old_messages": ["I have a car.", "My car is a Toyota Camry."], |
| 84 | + "incoming_message": "I own a blue Toyota Camry.", |
| 85 | + "expected_decision": "add", |
| 86 | + "reason": "The new message adds a specific detail (color) not covered by the old messages.", |
| 87 | + }, |
| 88 | +] |
| 89 | + |
| 90 | + |
| 91 | +async def evaluate_decision(memory_module, test_case): |
| 92 | + """Evaluate a single decision test case.""" |
| 93 | + conversation_id = str(uuid4()) |
| 94 | + |
| 95 | + # Add old messages |
| 96 | + for message_content in test_case["old_messages"]: |
| 97 | + message = UserMessageInput( |
| 98 | + id=str(uuid4()), |
| 99 | + content=message_content, |
| 100 | + author_id="user-123", |
| 101 | + conversation_ref=conversation_id, |
| 102 | + created_at=datetime.now() - timedelta(days=1), |
| 103 | + ) |
| 104 | + await memory_module.add_message(message) |
| 105 | + |
| 106 | + await memory_module.message_queue.message_buffer.scheduler.flush() |
| 107 | + |
| 108 | + # Create incoming message |
| 109 | + new_message = [ |
| 110 | + UserMessageInput( |
| 111 | + id=str(uuid4()), |
| 112 | + content=test_case["incoming_message"], |
| 113 | + author_id="user-123", |
| 114 | + conversation_ref=conversation_id, |
| 115 | + created_at=datetime.now(), |
| 116 | + ) |
| 117 | + ] |
| 118 | + |
| 119 | + # Get the decision |
| 120 | + extraction = await memory_module.memory_core._extract_semantic_fact_from_messages(new_message) |
| 121 | + if not (extraction.action == "add" and extraction.facts): |
| 122 | + return { |
| 123 | + "success": False, |
| 124 | + "error": "Failed to extract semantic facts", |
| 125 | + "test_case": test_case, |
| 126 | + "expected": test_case["expected_decision"], |
| 127 | + "got": "failed_extraction", |
| 128 | + "reason": "Failed to extract semantic facts", |
| 129 | + } |
| 130 | + |
| 131 | + for fact in extraction.facts: |
| 132 | + decision = await memory_module.memory_core._get_add_memory_processing_decision(fact.text, "user-123") |
| 133 | + return { |
| 134 | + "success": decision.decision == test_case["expected_decision"], |
| 135 | + "expected": test_case["expected_decision"], |
| 136 | + "got": decision.decision, |
| 137 | + "reason": decision.reason_for_decision, |
| 138 | + "test_case": test_case, |
| 139 | + } |
| 140 | + |
| 141 | + |
| 142 | +async def main(): |
| 143 | + # Initialize config and memory module |
| 144 | + llm_config = build_llm_config() |
| 145 | + if not llm_config.api_key: |
| 146 | + print("Error: OpenAI API key not provided") |
| 147 | + sys.exit(1) |
| 148 | + |
| 149 | + db_path = Path(__file__).parent / "data" / "evaluation" / "memory_module.db" |
| 150 | + # Create db directory if it doesn't exist |
| 151 | + db_path.parent.mkdir(parents=True, exist_ok=True) |
| 152 | + config = MemoryModuleConfig( |
| 153 | + db_path=db_path, |
| 154 | + buffer_size=5, |
| 155 | + timeout_seconds=60, |
| 156 | + llm=llm_config, |
| 157 | + ) |
| 158 | + |
| 159 | + # Delete existing db if it exists |
| 160 | + if db_path.exists(): |
| 161 | + db_path.unlink() |
| 162 | + |
| 163 | + memory_module = MemoryModule(config=config) |
| 164 | + |
| 165 | + results = [] |
| 166 | + successes = 0 |
| 167 | + failures = 0 |
| 168 | + |
| 169 | + # Run evaluations with progress bar |
| 170 | + print("\nEvaluating memory processing decisions...") |
| 171 | + for test_case in tqdm(TEST_CASES, desc="Processing test cases"): |
| 172 | + result = await evaluate_decision(memory_module, test_case) |
| 173 | + results.append(result) |
| 174 | + if result["success"]: |
| 175 | + successes += 1 |
| 176 | + else: |
| 177 | + failures += 1 |
| 178 | + |
| 179 | + # Calculate statistics |
| 180 | + total = len(TEST_CASES) |
| 181 | + success_rate = (successes / total) * 100 |
| 182 | + |
| 183 | + # Print summary |
| 184 | + print("\n=== Evaluation Summary ===") |
| 185 | + print(f"Total test cases: {total}") |
| 186 | + print(f"Successes: {successes} ({success_rate:.1f}%)") |
| 187 | + print(f"Failures: {failures} ({100 - success_rate:.1f}%)") |
| 188 | + |
| 189 | + # Print detailed failures if any |
| 190 | + if failures > 0: |
| 191 | + print("\n=== Failed Cases ===") |
| 192 | + for result in results: |
| 193 | + if not result["success"]: |
| 194 | + test_case = result["test_case"] |
| 195 | + print(f"\nTest Case: {test_case['title']}") |
| 196 | + print(f"Reason: {test_case['reason']}") |
| 197 | + print(f"Actual result: {result['reason']}") |
| 198 | + print(f"Expected: {result['expected']}") |
| 199 | + print(f"Got: {result['got']}") |
| 200 | + print("Old messages:") |
| 201 | + for msg in test_case["old_messages"]: |
| 202 | + print(f" - {msg}") |
| 203 | + print(f"New message: {test_case['incoming_message']}") |
| 204 | + print("-" * 50) |
| 205 | + |
| 206 | + # Cleanup |
| 207 | + message_queue = cast(MessageQueue, memory_module.message_queue) |
| 208 | + message_buffer = cast(MessageBuffer, message_queue.message_buffer) |
| 209 | + scheduler = cast(Scheduler, message_buffer.scheduler) |
| 210 | + await scheduler.cleanup() |
| 211 | + |
| 212 | + |
| 213 | +if __name__ == "__main__": |
| 214 | + asyncio.run(main()) |
0 commit comments