[kimi] Add unit tests for memory/crud.py (#1344) (#1358)
Some checks failed
Tests / lint (push) Failing after 30s
Tests / test (push) Has been skipped

This commit was merged in pull request #1358.
This commit is contained in:
2026-03-24 03:08:36 +00:00
parent 9e9dd5309a
commit 91d06eeb49

View File

@@ -0,0 +1,889 @@
"""Unit tests for timmy.memory.crud — Memory CRUD operations."""
import json
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from timmy.memory.crud import (
_build_search_filters,
_fetch_memory_candidates,
_row_to_entry,
_score_and_filter,
delete_memory,
get_memory_context,
get_memory_stats,
prune_memories,
recall_last_reflection,
recall_personal_facts,
recall_personal_facts_with_ids,
search_memories,
store_last_reflection,
store_memory,
store_personal_fact,
update_personal_fact,
)
from timmy.memory.db import MemoryEntry, get_connection
@pytest.fixture
def tmp_db_path(tmp_path):
"""Provide a temporary database path."""
return tmp_path / "test_memory.db"
@pytest.fixture
def patched_db(tmp_db_path):
"""Patch DB_PATH to use temporary database for tests."""
with patch("timmy.memory.db.DB_PATH", tmp_db_path):
# Create schema
with get_connection():
pass # Schema is created by get_connection
yield tmp_db_path
@pytest.fixture
def sample_entry():
"""Create a sample MemoryEntry for testing."""
return MemoryEntry(
id="test-id-123",
content="Test memory content",
source="test",
context_type="conversation",
agent_id="agent-1",
task_id="task-1",
session_id="session-1",
metadata={"key": "value"},
embedding=[0.1, 0.2, 0.3],
timestamp=datetime.now(UTC).isoformat(),
)
@pytest.mark.unit
class TestStoreMemory:
"""Tests for store_memory function."""
def test_store_memory_basic(self, patched_db):
"""Test storing a basic memory entry."""
entry = store_memory(
content="Hello world",
source="test",
context_type="conversation",
)
assert isinstance(entry, MemoryEntry)
assert entry.content == "Hello world"
assert entry.source == "test"
assert entry.context_type == "conversation"
assert entry.id is not None
def test_store_memory_with_all_fields(self, patched_db):
"""Test storing memory with all optional fields."""
entry = store_memory(
content="Full test content",
source="user",
context_type="fact",
agent_id="agent-42",
task_id="task-99",
session_id="session-xyz",
metadata={"priority": "high", "tags": ["test"]},
compute_embedding=False,
)
assert entry.content == "Full test content"
assert entry.agent_id == "agent-42"
assert entry.task_id == "task-99"
assert entry.session_id == "session-xyz"
assert entry.metadata == {"priority": "high", "tags": ["test"]}
assert entry.embedding is None
def test_store_memory_with_embedding(self, patched_db):
"""Test storing memory with embedding computation."""
# TIMMY_SKIP_EMBEDDINGS=1 in conftest, so uses hash fallback
entry = store_memory(
content="Test content for embedding",
source="system",
compute_embedding=True,
)
assert entry.embedding is not None
assert isinstance(entry.embedding, list)
assert len(entry.embedding) == 128 # Hash embedding dimension
def test_store_memory_without_embedding(self, patched_db):
"""Test storing memory without embedding."""
entry = store_memory(
content="No embedding needed",
source="test",
compute_embedding=False,
)
assert entry.embedding is None
def test_store_memory_persists_to_db(self, patched_db):
"""Test that stored memory is actually written to database."""
entry = store_memory(
content="Persisted content",
source="db_test",
context_type="document",
)
# Verify directly in DB
with get_connection() as conn:
row = conn.execute(
"SELECT * FROM memories WHERE id = ?",
(entry.id,),
).fetchone()
assert row is not None
assert row["content"] == "Persisted content"
assert row["memory_type"] == "document"
assert row["source"] == "db_test"
@pytest.mark.unit
class TestBuildSearchFilters:
"""Tests for _build_search_filters helper function."""
def test_no_filters(self):
"""Test building filters with no criteria."""
where_clause, params = _build_search_filters(None, None, None)
assert where_clause == ""
assert params == []
def test_context_type_filter(self):
"""Test filter by context_type only."""
where_clause, params = _build_search_filters("fact", None, None)
assert where_clause == "WHERE memory_type = ?"
assert params == ["fact"]
def test_agent_id_filter(self):
"""Test filter by agent_id only."""
where_clause, params = _build_search_filters(None, "agent-1", None)
assert where_clause == "WHERE agent_id = ?"
assert params == ["agent-1"]
def test_session_id_filter(self):
"""Test filter by session_id only."""
where_clause, params = _build_search_filters(None, None, "session-1")
assert where_clause == "WHERE session_id = ?"
assert params == ["session-1"]
def test_multiple_filters(self):
"""Test combining multiple filters."""
where_clause, params = _build_search_filters("conversation", "agent-1", "session-1")
assert where_clause == "WHERE memory_type = ? AND agent_id = ? AND session_id = ?"
assert params == ["conversation", "agent-1", "session-1"]
@pytest.mark.unit
class TestFetchMemoryCandidates:
"""Tests for _fetch_memory_candidates helper function."""
def test_fetch_with_data(self, patched_db):
"""Test fetching candidates when data exists."""
# Store some test data
for i in range(5):
store_memory(
content=f"Test content {i}",
source="fetch_test",
compute_embedding=False,
)
rows = _fetch_memory_candidates("", [], 10)
assert len(rows) == 5
def test_fetch_with_limit(self, patched_db):
"""Test that limit is respected."""
for i in range(10):
store_memory(
content=f"Test content {i}",
source="limit_test",
compute_embedding=False,
)
rows = _fetch_memory_candidates("", [], 3)
assert len(rows) == 3
def test_fetch_with_where_clause(self, patched_db):
"""Test fetching with WHERE clause."""
store_memory(
content="Fact content",
source="test",
context_type="fact",
compute_embedding=False,
)
store_memory(
content="Conversation content",
source="test",
context_type="conversation",
compute_embedding=False,
)
where_clause, params = _build_search_filters("fact", None, None)
rows = _fetch_memory_candidates(where_clause, params, 10)
assert len(rows) == 1
assert rows[0]["content"] == "Fact content"
@pytest.mark.unit
class TestRowToEntry:
"""Tests for _row_to_entry conversion function."""
def test_convert_basic_row(self):
"""Test converting a basic sqlite Row to MemoryEntry."""
# Create mock row
row_data = {
"id": "row-1",
"content": "Row content",
"memory_type": "conversation",
"source": "test",
"agent_id": "agent-1",
"task_id": "task-1",
"session_id": "session-1",
"metadata": None,
"embedding": None,
"created_at": "2026-03-23T10:00:00",
}
# Mock sqlite3.Row behavior
class MockRow:
def __getitem__(self, key):
return row_data.get(key)
entry = _row_to_entry(MockRow())
assert entry.id == "row-1"
assert entry.content == "Row content"
assert entry.context_type == "conversation" # memory_type -> context_type
assert entry.agent_id == "agent-1"
def test_convert_with_metadata(self):
"""Test converting row with JSON metadata."""
row_data = {
"id": "row-2",
"content": "Content with metadata",
"memory_type": "fact",
"source": "test",
"agent_id": None,
"task_id": None,
"session_id": None,
"metadata": '{"key": "value", "num": 42}',
"embedding": None,
"created_at": "2026-03-23T10:00:00",
}
class MockRow:
def __getitem__(self, key):
return row_data.get(key)
entry = _row_to_entry(MockRow())
assert entry.metadata == {"key": "value", "num": 42}
def test_convert_with_embedding(self):
"""Test converting row with JSON embedding."""
row_data = {
"id": "row-3",
"content": "Content with embedding",
"memory_type": "conversation",
"source": "test",
"agent_id": None,
"task_id": None,
"session_id": None,
"metadata": None,
"embedding": "[0.1, 0.2, 0.3]",
"created_at": "2026-03-23T10:00:00",
}
class MockRow:
def __getitem__(self, key):
return row_data.get(key)
entry = _row_to_entry(MockRow())
assert entry.embedding == [0.1, 0.2, 0.3]
@pytest.mark.unit
class TestScoreAndFilter:
"""Tests for _score_and_filter function."""
def test_empty_rows(self):
"""Test filtering empty rows list."""
results = _score_and_filter([], "query", [0.1, 0.2], 0.5)
assert results == []
def test_filter_by_min_relevance(self, patched_db):
"""Test filtering by minimum relevance score."""
# Create rows with embeddings
rows = []
for i in range(3):
entry = store_memory(
content=f"Content {i}",
source="test",
context_type="conversation",
compute_embedding=True, # Get actual embeddings
)
# Fetch row back
with get_connection() as conn:
row = conn.execute(
"SELECT * FROM memories WHERE id = ?",
(entry.id,),
).fetchone()
rows.append(row)
query_embedding = [0.1] * 128
results = _score_and_filter(rows, "query", query_embedding, 0.99) # High threshold
# Should filter out all results with high threshold
assert len(results) <= len(rows)
def test_keyword_fallback_no_embedding(self, patched_db):
"""Test keyword overlap fallback when row has no embedding."""
# Store without embedding
entry = store_memory(
content="hello world test",
source="test",
compute_embedding=False,
)
with get_connection() as conn:
row = conn.execute(
"SELECT * FROM memories WHERE id = ?",
(entry.id,),
).fetchone()
results = _score_and_filter([row], "hello world", [0.1] * 128, 0.0)
assert len(results) == 1
assert results[0].relevance_score is not None
@pytest.mark.unit
class TestSearchMemories:
"""Tests for search_memories function."""
def test_search_empty_db(self, patched_db):
"""Test searching empty database."""
results = search_memories("anything")
assert results == []
def test_search_returns_results(self, patched_db):
"""Test search returns matching results."""
store_memory(
content="Python programming is fun",
source="test",
context_type="conversation",
compute_embedding=True,
)
store_memory(
content="JavaScript is different",
source="test",
context_type="conversation",
compute_embedding=True,
)
results = search_memories("python programming", limit=5)
assert len(results) > 0
def test_search_with_filters(self, patched_db):
"""Test search with context_type filter."""
store_memory(
content="Fact about Python",
source="test",
context_type="fact",
compute_embedding=False,
)
store_memory(
content="Conversation about Python",
source="test",
context_type="conversation",
compute_embedding=False,
)
results = search_memories("python", context_type="fact")
assert len(results) == 1
assert results[0].context_type == "fact"
def test_search_with_agent_filter(self, patched_db):
"""Test search with agent_id filter."""
store_memory(
content="Agent 1 memory",
source="test",
agent_id="agent-1",
compute_embedding=False,
)
store_memory(
content="Agent 2 memory",
source="test",
agent_id="agent-2",
compute_embedding=False,
)
results = search_memories("memory", agent_id="agent-1")
assert len(results) == 1
assert results[0].agent_id == "agent-1"
def test_search_respects_limit(self, patched_db):
"""Test that limit parameter is respected."""
for i in range(10):
store_memory(
content=f"Memory {i}",
source="test",
compute_embedding=False,
)
results = search_memories("memory", limit=3)
assert len(results) <= 3
def test_search_with_min_relevance(self, patched_db):
"""Test search with min_relevance threshold."""
for i in range(5):
store_memory(
content=f"Unique content xyz{i}",
source="test",
compute_embedding=False,
)
# High threshold should return fewer results
results = search_memories("xyz", min_relevance=0.9, limit=10)
# With hash embeddings, high threshold may filter everything
assert isinstance(results, list)
@pytest.mark.unit
class TestDeleteMemory:
"""Tests for delete_memory function."""
def test_delete_existing_memory(self, patched_db):
"""Test deleting an existing memory."""
entry = store_memory(
content="To be deleted",
source="test",
compute_embedding=False,
)
result = delete_memory(entry.id)
assert result is True
# Verify deletion
with get_connection() as conn:
row = conn.execute(
"SELECT * FROM memories WHERE id = ?",
(entry.id,),
).fetchone()
assert row is None
def test_delete_nonexistent_memory(self, patched_db):
"""Test deleting a non-existent memory."""
result = delete_memory("nonexistent-id-12345")
assert result is False
def test_delete_multiple_memories(self, patched_db):
"""Test deleting multiple memories one by one."""
entries = []
for i in range(3):
entry = store_memory(
content=f"Delete me {i}",
source="test",
compute_embedding=False,
)
entries.append(entry)
for entry in entries:
result = delete_memory(entry.id)
assert result is True
@pytest.mark.unit
class TestGetMemoryStats:
"""Tests for get_memory_stats function."""
def test_stats_empty_db(self, patched_db):
"""Test stats on empty database."""
stats = get_memory_stats()
assert stats["total_entries"] == 0
assert stats["by_type"] == {}
assert stats["with_embeddings"] == 0
def test_stats_with_entries(self, patched_db):
"""Test stats with various entries."""
store_memory(
content="Fact 1",
source="test",
context_type="fact",
compute_embedding=False,
)
store_memory(
content="Fact 2",
source="test",
context_type="fact",
compute_embedding=True,
)
store_memory(
content="Conversation 1",
source="test",
context_type="conversation",
compute_embedding=False,
)
stats = get_memory_stats()
assert stats["total_entries"] == 3
assert stats["by_type"]["fact"] == 2
assert stats["by_type"]["conversation"] == 1
assert stats["with_embeddings"] == 1
def test_stats_embedding_model_status(self, patched_db):
"""Test that stats reports embedding model status."""
stats = get_memory_stats()
assert "has_embedding_model" in stats
# In test mode, embeddings are skipped
assert stats["has_embedding_model"] is False
@pytest.mark.unit
class TestPruneMemories:
"""Tests for prune_memories function."""
def test_prune_empty_db(self, patched_db):
"""Test pruning empty database."""
deleted = prune_memories(older_than_days=30)
assert deleted == 0
def test_prune_old_memories(self, patched_db):
"""Test pruning old memories."""
# Store a memory
store_memory(
content="Recent memory",
source="test",
compute_embedding=False,
)
# Prune shouldn't delete recent memories
deleted = prune_memories(older_than_days=30)
assert deleted == 0
def test_prune_keeps_facts(self, patched_db):
"""Test that prune keeps fact-type memories when keep_facts=True."""
# Insert an old fact directly
old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat()
with get_connection() as conn:
conn.execute(
"""
INSERT INTO memories
(id, content, memory_type, source, created_at)
VALUES (?, ?, 'fact', 'test', ?)
""",
("old-fact-1", "Old fact", old_time),
)
conn.commit()
# Prune with keep_facts=True should not delete facts
deleted = prune_memories(older_than_days=30, keep_facts=True)
assert deleted == 0
# Verify fact still exists
with get_connection() as conn:
row = conn.execute(
"SELECT * FROM memories WHERE id = ?",
("old-fact-1",),
).fetchone()
assert row is not None
def test_prune_deletes_facts_when_keep_false(self, patched_db):
"""Test that prune deletes facts when keep_facts=False."""
# Insert an old fact directly
old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat()
with get_connection() as conn:
conn.execute(
"""
INSERT INTO memories
(id, content, memory_type, source, created_at)
VALUES (?, ?, 'fact', 'test', ?)
""",
("old-fact-2", "Old fact to delete", old_time),
)
conn.commit()
# Prune with keep_facts=False should delete facts
deleted = prune_memories(older_than_days=30, keep_facts=False)
assert deleted == 1
def test_prune_non_fact_memories(self, patched_db):
"""Test pruning non-fact memories."""
# Insert old non-fact memory
old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat()
with get_connection() as conn:
conn.execute(
"""
INSERT INTO memories
(id, content, memory_type, source, created_at)
VALUES (?, ?, 'conversation', 'test', ?)
""",
("old-conv-1", "Old conversation", old_time),
)
conn.commit()
deleted = prune_memories(older_than_days=30, keep_facts=True)
assert deleted == 1
@pytest.mark.unit
class TestGetMemoryContext:
"""Tests for get_memory_context function."""
def test_empty_context(self, patched_db):
"""Test getting context from empty database."""
context = get_memory_context("query")
assert context == ""
def test_context_with_results(self, patched_db):
"""Test getting context with matching results."""
store_memory(
content="Python is a programming language",
source="user",
compute_embedding=False,
)
context = get_memory_context("python programming")
# May or may not match depending on search
assert isinstance(context, str)
def test_context_respects_max_tokens(self, patched_db):
"""Test that max_tokens limits context size."""
for i in range(20):
store_memory(
content=f"This is memory number {i} with some content",
source="test",
compute_embedding=False,
)
context = get_memory_context("memory", max_tokens=100)
# Rough approximation: 100 tokens * 4 chars = 400 chars
assert len(context) <= 500 or context == ""
def test_context_formatting(self, patched_db):
"""Test that context is properly formatted."""
store_memory(
content="Important information",
source="system",
compute_embedding=False,
)
context = get_memory_context("important")
if context:
assert "Relevant context from memory:" in context or context == ""
@pytest.mark.unit
class TestPersonalFacts:
"""Tests for personal facts functions."""
def test_recall_personal_facts_empty(self, patched_db):
"""Test recalling facts when none exist."""
facts = recall_personal_facts()
assert facts == []
def test_store_and_recall_personal_fact(self, patched_db):
"""Test storing and recalling a personal fact."""
entry = store_personal_fact("User likes Python", agent_id="agent-1")
assert entry.context_type == "fact"
assert entry.content == "User likes Python"
assert entry.agent_id == "agent-1"
facts = recall_personal_facts()
assert "User likes Python" in facts
def test_recall_personal_facts_with_agent_filter(self, patched_db):
"""Test recalling facts filtered by agent_id."""
store_personal_fact("Fact for agent 1", agent_id="agent-1")
store_personal_fact("Fact for agent 2", agent_id="agent-2")
facts = recall_personal_facts(agent_id="agent-1")
assert len(facts) == 1
assert "Fact for agent 1" in facts
def test_recall_personal_facts_with_ids(self, patched_db):
"""Test recalling facts with their IDs."""
entry = store_personal_fact("Fact with ID", agent_id="agent-1")
facts_with_ids = recall_personal_facts_with_ids()
assert len(facts_with_ids) == 1
assert facts_with_ids[0]["id"] == entry.id
assert facts_with_ids[0]["content"] == "Fact with ID"
def test_update_personal_fact(self, patched_db):
"""Test updating a personal fact."""
entry = store_personal_fact("Original fact", agent_id="agent-1")
result = update_personal_fact(entry.id, "Updated fact")
assert result is True
facts = recall_personal_facts()
assert "Updated fact" in facts
assert "Original fact" not in facts
def test_update_nonexistent_fact(self, patched_db):
"""Test updating a non-existent fact."""
result = update_personal_fact("nonexistent-id", "New content")
assert result is False
def test_update_only_affects_facts(self, patched_db):
"""Test that update only affects fact-type memories."""
# Store a non-fact memory
entry = store_memory(
content="Not a fact",
source="test",
context_type="conversation",
compute_embedding=False,
)
# Try to update it as if it were a fact
result = update_personal_fact(entry.id, "Updated content")
assert result is False
@pytest.mark.unit
class TestReflections:
"""Tests for reflection storage and recall."""
def test_store_and_recall_reflection(self, patched_db):
"""Test storing and recalling a reflection."""
store_last_reflection("This is my reflection")
result = recall_last_reflection()
assert result == "This is my reflection"
def test_reflection_replaces_previous(self, patched_db):
"""Test that storing reflection replaces the previous one."""
store_last_reflection("First reflection")
store_last_reflection("Second reflection")
result = recall_last_reflection()
assert result == "Second reflection"
# Verify only one reflection in DB
with get_connection() as conn:
count = conn.execute(
"SELECT COUNT(*) FROM memories WHERE memory_type = 'reflection'"
).fetchone()[0]
assert count == 1
def test_store_empty_reflection(self, patched_db):
"""Test that empty reflection is not stored."""
store_last_reflection("")
store_last_reflection(" ")
store_last_reflection(None)
result = recall_last_reflection()
assert result is None
def test_recall_no_reflection(self, patched_db):
"""Test recalling when no reflection exists."""
result = recall_last_reflection()
assert result is None
def test_reflection_strips_whitespace(self, patched_db):
"""Test that reflection content is stripped."""
store_last_reflection(" Reflection with whitespace ")
result = recall_last_reflection()
assert result == "Reflection with whitespace"
@pytest.mark.unit
class TestEdgeCases:
"""Edge cases and error handling."""
def test_unicode_content(self, patched_db):
"""Test handling of unicode content."""
entry = store_memory(
content="Unicode: 你好世界 🎉 café naïve",
source="test",
compute_embedding=False,
)
assert entry.content == "Unicode: 你好世界 🎉 café naïve"
# Verify in DB
with get_connection() as conn:
row = conn.execute(
"SELECT content FROM memories WHERE id = ?",
(entry.id,),
).fetchone()
assert "你好世界" in row["content"]
def test_special_characters_in_content(self, patched_db):
"""Test handling of special characters."""
content = """<script>alert('xss')</script>
SQL: SELECT * FROM users
JSON: {"key": "value"}
Escapes: \\n \\t"""
entry = store_memory(
content=content,
source="test",
compute_embedding=False,
)
assert entry.content == content
def test_very_long_content(self, patched_db):
"""Test handling of very long content."""
long_content = "Word " * 1000
entry = store_memory(
content=long_content,
source="test",
compute_embedding=False,
)
assert len(entry.content) == len(long_content)
def test_metadata_with_nested_structure(self, patched_db):
"""Test storing metadata with nested structure."""
metadata = {
"level1": {
"level2": {
"level3": ["item1", "item2"]
}
},
"number": 42,
"boolean": True,
"null": None,
}
entry = store_memory(
content="Nested metadata test",
source="test",
metadata=metadata,
compute_embedding=False,
)
# Verify metadata round-trips correctly
with get_connection() as conn:
row = conn.execute(
"SELECT metadata FROM memories WHERE id = ?",
(entry.id,),
).fetchone()
loaded = json.loads(row["metadata"])
assert loaded["level1"]["level2"]["level3"] == ["item1", "item2"]
assert loaded["number"] == 42
assert loaded["boolean"] is True
def test_duplicate_keys_not_prevented(self, patched_db):
"""Test that duplicate content is allowed."""
entry1 = store_memory(
content="Duplicate content",
source="test",
compute_embedding=False,
)
entry2 = store_memory(
content="Duplicate content",
source="test",
compute_embedding=False,
)
assert entry1.id != entry2.id
stats = get_memory_stats()
assert stats["total_entries"] == 2