From 91d06eeb49dc203f3618ae50c5b59c17dd6f0275 Mon Sep 17 00:00:00 2001 From: Kimi Agent Date: Tue, 24 Mar 2026 03:08:36 +0000 Subject: [PATCH] [kimi] Add unit tests for memory/crud.py (#1344) (#1358) --- tests/timmy/test_memory_crud.py | 889 ++++++++++++++++++++++++++++++++ 1 file changed, 889 insertions(+) create mode 100644 tests/timmy/test_memory_crud.py diff --git a/tests/timmy/test_memory_crud.py b/tests/timmy/test_memory_crud.py new file mode 100644 index 00000000..d25ff6f5 --- /dev/null +++ b/tests/timmy/test_memory_crud.py @@ -0,0 +1,889 @@ +"""Unit tests for timmy.memory.crud — Memory CRUD operations.""" + +import json +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import pytest + +from timmy.memory.crud import ( + _build_search_filters, + _fetch_memory_candidates, + _row_to_entry, + _score_and_filter, + delete_memory, + get_memory_context, + get_memory_stats, + prune_memories, + recall_last_reflection, + recall_personal_facts, + recall_personal_facts_with_ids, + search_memories, + store_last_reflection, + store_memory, + store_personal_fact, + update_personal_fact, +) +from timmy.memory.db import MemoryEntry, get_connection + + +@pytest.fixture +def tmp_db_path(tmp_path): + """Provide a temporary database path.""" + return tmp_path / "test_memory.db" + + +@pytest.fixture +def patched_db(tmp_db_path): + """Patch DB_PATH to use temporary database for tests.""" + with patch("timmy.memory.db.DB_PATH", tmp_db_path): + # Create schema + with get_connection(): + pass # Schema is created by get_connection + yield tmp_db_path + + +@pytest.fixture +def sample_entry(): + """Create a sample MemoryEntry for testing.""" + return MemoryEntry( + id="test-id-123", + content="Test memory content", + source="test", + context_type="conversation", + agent_id="agent-1", + task_id="task-1", + session_id="session-1", + metadata={"key": "value"}, + embedding=[0.1, 0.2, 0.3], + timestamp=datetime.now(UTC).isoformat(), + ) + + +@pytest.mark.unit +class TestStoreMemory: + """Tests for store_memory function.""" + + def test_store_memory_basic(self, patched_db): + """Test storing a basic memory entry.""" + entry = store_memory( + content="Hello world", + source="test", + context_type="conversation", + ) + + assert isinstance(entry, MemoryEntry) + assert entry.content == "Hello world" + assert entry.source == "test" + assert entry.context_type == "conversation" + assert entry.id is not None + + def test_store_memory_with_all_fields(self, patched_db): + """Test storing memory with all optional fields.""" + entry = store_memory( + content="Full test content", + source="user", + context_type="fact", + agent_id="agent-42", + task_id="task-99", + session_id="session-xyz", + metadata={"priority": "high", "tags": ["test"]}, + compute_embedding=False, + ) + + assert entry.content == "Full test content" + assert entry.agent_id == "agent-42" + assert entry.task_id == "task-99" + assert entry.session_id == "session-xyz" + assert entry.metadata == {"priority": "high", "tags": ["test"]} + assert entry.embedding is None + + def test_store_memory_with_embedding(self, patched_db): + """Test storing memory with embedding computation.""" + # TIMMY_SKIP_EMBEDDINGS=1 in conftest, so uses hash fallback + entry = store_memory( + content="Test content for embedding", + source="system", + compute_embedding=True, + ) + + assert entry.embedding is not None + assert isinstance(entry.embedding, list) + assert len(entry.embedding) == 128 # Hash embedding dimension + + def test_store_memory_without_embedding(self, patched_db): + """Test storing memory without embedding.""" + entry = store_memory( + content="No embedding needed", + source="test", + compute_embedding=False, + ) + + assert entry.embedding is None + + def test_store_memory_persists_to_db(self, patched_db): + """Test that stored memory is actually written to database.""" + entry = store_memory( + content="Persisted content", + source="db_test", + context_type="document", + ) + + # Verify directly in DB + with get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE id = ?", + (entry.id,), + ).fetchone() + + assert row is not None + assert row["content"] == "Persisted content" + assert row["memory_type"] == "document" + assert row["source"] == "db_test" + + +@pytest.mark.unit +class TestBuildSearchFilters: + """Tests for _build_search_filters helper function.""" + + def test_no_filters(self): + """Test building filters with no criteria.""" + where_clause, params = _build_search_filters(None, None, None) + assert where_clause == "" + assert params == [] + + def test_context_type_filter(self): + """Test filter by context_type only.""" + where_clause, params = _build_search_filters("fact", None, None) + assert where_clause == "WHERE memory_type = ?" + assert params == ["fact"] + + def test_agent_id_filter(self): + """Test filter by agent_id only.""" + where_clause, params = _build_search_filters(None, "agent-1", None) + assert where_clause == "WHERE agent_id = ?" + assert params == ["agent-1"] + + def test_session_id_filter(self): + """Test filter by session_id only.""" + where_clause, params = _build_search_filters(None, None, "session-1") + assert where_clause == "WHERE session_id = ?" + assert params == ["session-1"] + + def test_multiple_filters(self): + """Test combining multiple filters.""" + where_clause, params = _build_search_filters("conversation", "agent-1", "session-1") + assert where_clause == "WHERE memory_type = ? AND agent_id = ? AND session_id = ?" + assert params == ["conversation", "agent-1", "session-1"] + + +@pytest.mark.unit +class TestFetchMemoryCandidates: + """Tests for _fetch_memory_candidates helper function.""" + + def test_fetch_with_data(self, patched_db): + """Test fetching candidates when data exists.""" + # Store some test data + for i in range(5): + store_memory( + content=f"Test content {i}", + source="fetch_test", + compute_embedding=False, + ) + + rows = _fetch_memory_candidates("", [], 10) + assert len(rows) == 5 + + def test_fetch_with_limit(self, patched_db): + """Test that limit is respected.""" + for i in range(10): + store_memory( + content=f"Test content {i}", + source="limit_test", + compute_embedding=False, + ) + + rows = _fetch_memory_candidates("", [], 3) + assert len(rows) == 3 + + def test_fetch_with_where_clause(self, patched_db): + """Test fetching with WHERE clause.""" + store_memory( + content="Fact content", + source="test", + context_type="fact", + compute_embedding=False, + ) + store_memory( + content="Conversation content", + source="test", + context_type="conversation", + compute_embedding=False, + ) + + where_clause, params = _build_search_filters("fact", None, None) + rows = _fetch_memory_candidates(where_clause, params, 10) + assert len(rows) == 1 + assert rows[0]["content"] == "Fact content" + + +@pytest.mark.unit +class TestRowToEntry: + """Tests for _row_to_entry conversion function.""" + + def test_convert_basic_row(self): + """Test converting a basic sqlite Row to MemoryEntry.""" + # Create mock row + row_data = { + "id": "row-1", + "content": "Row content", + "memory_type": "conversation", + "source": "test", + "agent_id": "agent-1", + "task_id": "task-1", + "session_id": "session-1", + "metadata": None, + "embedding": None, + "created_at": "2026-03-23T10:00:00", + } + + # Mock sqlite3.Row behavior + class MockRow: + def __getitem__(self, key): + return row_data.get(key) + + entry = _row_to_entry(MockRow()) + assert entry.id == "row-1" + assert entry.content == "Row content" + assert entry.context_type == "conversation" # memory_type -> context_type + assert entry.agent_id == "agent-1" + + def test_convert_with_metadata(self): + """Test converting row with JSON metadata.""" + row_data = { + "id": "row-2", + "content": "Content with metadata", + "memory_type": "fact", + "source": "test", + "agent_id": None, + "task_id": None, + "session_id": None, + "metadata": '{"key": "value", "num": 42}', + "embedding": None, + "created_at": "2026-03-23T10:00:00", + } + + class MockRow: + def __getitem__(self, key): + return row_data.get(key) + + entry = _row_to_entry(MockRow()) + assert entry.metadata == {"key": "value", "num": 42} + + def test_convert_with_embedding(self): + """Test converting row with JSON embedding.""" + row_data = { + "id": "row-3", + "content": "Content with embedding", + "memory_type": "conversation", + "source": "test", + "agent_id": None, + "task_id": None, + "session_id": None, + "metadata": None, + "embedding": "[0.1, 0.2, 0.3]", + "created_at": "2026-03-23T10:00:00", + } + + class MockRow: + def __getitem__(self, key): + return row_data.get(key) + + entry = _row_to_entry(MockRow()) + assert entry.embedding == [0.1, 0.2, 0.3] + + +@pytest.mark.unit +class TestScoreAndFilter: + """Tests for _score_and_filter function.""" + + def test_empty_rows(self): + """Test filtering empty rows list.""" + results = _score_and_filter([], "query", [0.1, 0.2], 0.5) + assert results == [] + + def test_filter_by_min_relevance(self, patched_db): + """Test filtering by minimum relevance score.""" + # Create rows with embeddings + rows = [] + for i in range(3): + entry = store_memory( + content=f"Content {i}", + source="test", + context_type="conversation", + compute_embedding=True, # Get actual embeddings + ) + # Fetch row back + with get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE id = ?", + (entry.id,), + ).fetchone() + rows.append(row) + + query_embedding = [0.1] * 128 + results = _score_and_filter(rows, "query", query_embedding, 0.99) # High threshold + # Should filter out all results with high threshold + assert len(results) <= len(rows) + + def test_keyword_fallback_no_embedding(self, patched_db): + """Test keyword overlap fallback when row has no embedding.""" + # Store without embedding + entry = store_memory( + content="hello world test", + source="test", + compute_embedding=False, + ) + + with get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE id = ?", + (entry.id,), + ).fetchone() + + results = _score_and_filter([row], "hello world", [0.1] * 128, 0.0) + assert len(results) == 1 + assert results[0].relevance_score is not None + + +@pytest.mark.unit +class TestSearchMemories: + """Tests for search_memories function.""" + + def test_search_empty_db(self, patched_db): + """Test searching empty database.""" + results = search_memories("anything") + assert results == [] + + def test_search_returns_results(self, patched_db): + """Test search returns matching results.""" + store_memory( + content="Python programming is fun", + source="test", + context_type="conversation", + compute_embedding=True, + ) + store_memory( + content="JavaScript is different", + source="test", + context_type="conversation", + compute_embedding=True, + ) + + results = search_memories("python programming", limit=5) + assert len(results) > 0 + + def test_search_with_filters(self, patched_db): + """Test search with context_type filter.""" + store_memory( + content="Fact about Python", + source="test", + context_type="fact", + compute_embedding=False, + ) + store_memory( + content="Conversation about Python", + source="test", + context_type="conversation", + compute_embedding=False, + ) + + results = search_memories("python", context_type="fact") + assert len(results) == 1 + assert results[0].context_type == "fact" + + def test_search_with_agent_filter(self, patched_db): + """Test search with agent_id filter.""" + store_memory( + content="Agent 1 memory", + source="test", + agent_id="agent-1", + compute_embedding=False, + ) + store_memory( + content="Agent 2 memory", + source="test", + agent_id="agent-2", + compute_embedding=False, + ) + + results = search_memories("memory", agent_id="agent-1") + assert len(results) == 1 + assert results[0].agent_id == "agent-1" + + def test_search_respects_limit(self, patched_db): + """Test that limit parameter is respected.""" + for i in range(10): + store_memory( + content=f"Memory {i}", + source="test", + compute_embedding=False, + ) + + results = search_memories("memory", limit=3) + assert len(results) <= 3 + + def test_search_with_min_relevance(self, patched_db): + """Test search with min_relevance threshold.""" + for i in range(5): + store_memory( + content=f"Unique content xyz{i}", + source="test", + compute_embedding=False, + ) + + # High threshold should return fewer results + results = search_memories("xyz", min_relevance=0.9, limit=10) + # With hash embeddings, high threshold may filter everything + assert isinstance(results, list) + + +@pytest.mark.unit +class TestDeleteMemory: + """Tests for delete_memory function.""" + + def test_delete_existing_memory(self, patched_db): + """Test deleting an existing memory.""" + entry = store_memory( + content="To be deleted", + source="test", + compute_embedding=False, + ) + + result = delete_memory(entry.id) + assert result is True + + # Verify deletion + with get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE id = ?", + (entry.id,), + ).fetchone() + assert row is None + + def test_delete_nonexistent_memory(self, patched_db): + """Test deleting a non-existent memory.""" + result = delete_memory("nonexistent-id-12345") + assert result is False + + def test_delete_multiple_memories(self, patched_db): + """Test deleting multiple memories one by one.""" + entries = [] + for i in range(3): + entry = store_memory( + content=f"Delete me {i}", + source="test", + compute_embedding=False, + ) + entries.append(entry) + + for entry in entries: + result = delete_memory(entry.id) + assert result is True + + +@pytest.mark.unit +class TestGetMemoryStats: + """Tests for get_memory_stats function.""" + + def test_stats_empty_db(self, patched_db): + """Test stats on empty database.""" + stats = get_memory_stats() + assert stats["total_entries"] == 0 + assert stats["by_type"] == {} + assert stats["with_embeddings"] == 0 + + def test_stats_with_entries(self, patched_db): + """Test stats with various entries.""" + store_memory( + content="Fact 1", + source="test", + context_type="fact", + compute_embedding=False, + ) + store_memory( + content="Fact 2", + source="test", + context_type="fact", + compute_embedding=True, + ) + store_memory( + content="Conversation 1", + source="test", + context_type="conversation", + compute_embedding=False, + ) + + stats = get_memory_stats() + assert stats["total_entries"] == 3 + assert stats["by_type"]["fact"] == 2 + assert stats["by_type"]["conversation"] == 1 + assert stats["with_embeddings"] == 1 + + def test_stats_embedding_model_status(self, patched_db): + """Test that stats reports embedding model status.""" + stats = get_memory_stats() + assert "has_embedding_model" in stats + # In test mode, embeddings are skipped + assert stats["has_embedding_model"] is False + + +@pytest.mark.unit +class TestPruneMemories: + """Tests for prune_memories function.""" + + def test_prune_empty_db(self, patched_db): + """Test pruning empty database.""" + deleted = prune_memories(older_than_days=30) + assert deleted == 0 + + def test_prune_old_memories(self, patched_db): + """Test pruning old memories.""" + # Store a memory + store_memory( + content="Recent memory", + source="test", + compute_embedding=False, + ) + + # Prune shouldn't delete recent memories + deleted = prune_memories(older_than_days=30) + assert deleted == 0 + + def test_prune_keeps_facts(self, patched_db): + """Test that prune keeps fact-type memories when keep_facts=True.""" + # Insert an old fact directly + old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat() + with get_connection() as conn: + conn.execute( + """ + INSERT INTO memories + (id, content, memory_type, source, created_at) + VALUES (?, ?, 'fact', 'test', ?) + """, + ("old-fact-1", "Old fact", old_time), + ) + conn.commit() + + # Prune with keep_facts=True should not delete facts + deleted = prune_memories(older_than_days=30, keep_facts=True) + assert deleted == 0 + + # Verify fact still exists + with get_connection() as conn: + row = conn.execute( + "SELECT * FROM memories WHERE id = ?", + ("old-fact-1",), + ).fetchone() + assert row is not None + + def test_prune_deletes_facts_when_keep_false(self, patched_db): + """Test that prune deletes facts when keep_facts=False.""" + # Insert an old fact directly + old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat() + with get_connection() as conn: + conn.execute( + """ + INSERT INTO memories + (id, content, memory_type, source, created_at) + VALUES (?, ?, 'fact', 'test', ?) + """, + ("old-fact-2", "Old fact to delete", old_time), + ) + conn.commit() + + # Prune with keep_facts=False should delete facts + deleted = prune_memories(older_than_days=30, keep_facts=False) + assert deleted == 1 + + def test_prune_non_fact_memories(self, patched_db): + """Test pruning non-fact memories.""" + # Insert old non-fact memory + old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat() + with get_connection() as conn: + conn.execute( + """ + INSERT INTO memories + (id, content, memory_type, source, created_at) + VALUES (?, ?, 'conversation', 'test', ?) + """, + ("old-conv-1", "Old conversation", old_time), + ) + conn.commit() + + deleted = prune_memories(older_than_days=30, keep_facts=True) + assert deleted == 1 + + +@pytest.mark.unit +class TestGetMemoryContext: + """Tests for get_memory_context function.""" + + def test_empty_context(self, patched_db): + """Test getting context from empty database.""" + context = get_memory_context("query") + assert context == "" + + def test_context_with_results(self, patched_db): + """Test getting context with matching results.""" + store_memory( + content="Python is a programming language", + source="user", + compute_embedding=False, + ) + + context = get_memory_context("python programming") + # May or may not match depending on search + assert isinstance(context, str) + + def test_context_respects_max_tokens(self, patched_db): + """Test that max_tokens limits context size.""" + for i in range(20): + store_memory( + content=f"This is memory number {i} with some content", + source="test", + compute_embedding=False, + ) + + context = get_memory_context("memory", max_tokens=100) + # Rough approximation: 100 tokens * 4 chars = 400 chars + assert len(context) <= 500 or context == "" + + def test_context_formatting(self, patched_db): + """Test that context is properly formatted.""" + store_memory( + content="Important information", + source="system", + compute_embedding=False, + ) + + context = get_memory_context("important") + if context: + assert "Relevant context from memory:" in context or context == "" + + +@pytest.mark.unit +class TestPersonalFacts: + """Tests for personal facts functions.""" + + def test_recall_personal_facts_empty(self, patched_db): + """Test recalling facts when none exist.""" + facts = recall_personal_facts() + assert facts == [] + + def test_store_and_recall_personal_fact(self, patched_db): + """Test storing and recalling a personal fact.""" + entry = store_personal_fact("User likes Python", agent_id="agent-1") + + assert entry.context_type == "fact" + assert entry.content == "User likes Python" + assert entry.agent_id == "agent-1" + + facts = recall_personal_facts() + assert "User likes Python" in facts + + def test_recall_personal_facts_with_agent_filter(self, patched_db): + """Test recalling facts filtered by agent_id.""" + store_personal_fact("Fact for agent 1", agent_id="agent-1") + store_personal_fact("Fact for agent 2", agent_id="agent-2") + + facts = recall_personal_facts(agent_id="agent-1") + assert len(facts) == 1 + assert "Fact for agent 1" in facts + + def test_recall_personal_facts_with_ids(self, patched_db): + """Test recalling facts with their IDs.""" + entry = store_personal_fact("Fact with ID", agent_id="agent-1") + + facts_with_ids = recall_personal_facts_with_ids() + assert len(facts_with_ids) == 1 + assert facts_with_ids[0]["id"] == entry.id + assert facts_with_ids[0]["content"] == "Fact with ID" + + def test_update_personal_fact(self, patched_db): + """Test updating a personal fact.""" + entry = store_personal_fact("Original fact", agent_id="agent-1") + + result = update_personal_fact(entry.id, "Updated fact") + assert result is True + + facts = recall_personal_facts() + assert "Updated fact" in facts + assert "Original fact" not in facts + + def test_update_nonexistent_fact(self, patched_db): + """Test updating a non-existent fact.""" + result = update_personal_fact("nonexistent-id", "New content") + assert result is False + + def test_update_only_affects_facts(self, patched_db): + """Test that update only affects fact-type memories.""" + # Store a non-fact memory + entry = store_memory( + content="Not a fact", + source="test", + context_type="conversation", + compute_embedding=False, + ) + + # Try to update it as if it were a fact + result = update_personal_fact(entry.id, "Updated content") + assert result is False + + +@pytest.mark.unit +class TestReflections: + """Tests for reflection storage and recall.""" + + def test_store_and_recall_reflection(self, patched_db): + """Test storing and recalling a reflection.""" + store_last_reflection("This is my reflection") + + result = recall_last_reflection() + assert result == "This is my reflection" + + def test_reflection_replaces_previous(self, patched_db): + """Test that storing reflection replaces the previous one.""" + store_last_reflection("First reflection") + store_last_reflection("Second reflection") + + result = recall_last_reflection() + assert result == "Second reflection" + + # Verify only one reflection in DB + with get_connection() as conn: + count = conn.execute( + "SELECT COUNT(*) FROM memories WHERE memory_type = 'reflection'" + ).fetchone()[0] + assert count == 1 + + def test_store_empty_reflection(self, patched_db): + """Test that empty reflection is not stored.""" + store_last_reflection("") + store_last_reflection(" ") + store_last_reflection(None) + + result = recall_last_reflection() + assert result is None + + def test_recall_no_reflection(self, patched_db): + """Test recalling when no reflection exists.""" + result = recall_last_reflection() + assert result is None + + def test_reflection_strips_whitespace(self, patched_db): + """Test that reflection content is stripped.""" + store_last_reflection(" Reflection with whitespace ") + + result = recall_last_reflection() + assert result == "Reflection with whitespace" + + +@pytest.mark.unit +class TestEdgeCases: + """Edge cases and error handling.""" + + def test_unicode_content(self, patched_db): + """Test handling of unicode content.""" + entry = store_memory( + content="Unicode: 你好世界 🎉 café naïve", + source="test", + compute_embedding=False, + ) + + assert entry.content == "Unicode: 你好世界 🎉 café naïve" + + # Verify in DB + with get_connection() as conn: + row = conn.execute( + "SELECT content FROM memories WHERE id = ?", + (entry.id,), + ).fetchone() + assert "你好世界" in row["content"] + + def test_special_characters_in_content(self, patched_db): + """Test handling of special characters.""" + content = """ + SQL: SELECT * FROM users + JSON: {"key": "value"} + Escapes: \\n \\t""" + + entry = store_memory( + content=content, + source="test", + compute_embedding=False, + ) + + assert entry.content == content + + def test_very_long_content(self, patched_db): + """Test handling of very long content.""" + long_content = "Word " * 1000 + + entry = store_memory( + content=long_content, + source="test", + compute_embedding=False, + ) + + assert len(entry.content) == len(long_content) + + def test_metadata_with_nested_structure(self, patched_db): + """Test storing metadata with nested structure.""" + metadata = { + "level1": { + "level2": { + "level3": ["item1", "item2"] + } + }, + "number": 42, + "boolean": True, + "null": None, + } + + entry = store_memory( + content="Nested metadata test", + source="test", + metadata=metadata, + compute_embedding=False, + ) + + # Verify metadata round-trips correctly + with get_connection() as conn: + row = conn.execute( + "SELECT metadata FROM memories WHERE id = ?", + (entry.id,), + ).fetchone() + + loaded = json.loads(row["metadata"]) + assert loaded["level1"]["level2"]["level3"] == ["item1", "item2"] + assert loaded["number"] == 42 + assert loaded["boolean"] is True + + def test_duplicate_keys_not_prevented(self, patched_db): + """Test that duplicate content is allowed.""" + entry1 = store_memory( + content="Duplicate content", + source="test", + compute_embedding=False, + ) + entry2 = store_memory( + content="Duplicate content", + source="test", + compute_embedding=False, + ) + + assert entry1.id != entry2.id + + stats = get_memory_stats() + assert stats["total_entries"] == 2