This commit was merged in pull request #1358.
This commit is contained in:
889
tests/timmy/test_memory_crud.py
Normal file
889
tests/timmy/test_memory_crud.py
Normal file
@@ -0,0 +1,889 @@
|
||||
"""Unit tests for timmy.memory.crud — Memory CRUD operations."""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from timmy.memory.crud import (
|
||||
_build_search_filters,
|
||||
_fetch_memory_candidates,
|
||||
_row_to_entry,
|
||||
_score_and_filter,
|
||||
delete_memory,
|
||||
get_memory_context,
|
||||
get_memory_stats,
|
||||
prune_memories,
|
||||
recall_last_reflection,
|
||||
recall_personal_facts,
|
||||
recall_personal_facts_with_ids,
|
||||
search_memories,
|
||||
store_last_reflection,
|
||||
store_memory,
|
||||
store_personal_fact,
|
||||
update_personal_fact,
|
||||
)
|
||||
from timmy.memory.db import MemoryEntry, get_connection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_db_path(tmp_path):
|
||||
"""Provide a temporary database path."""
|
||||
return tmp_path / "test_memory.db"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_db(tmp_db_path):
|
||||
"""Patch DB_PATH to use temporary database for tests."""
|
||||
with patch("timmy.memory.db.DB_PATH", tmp_db_path):
|
||||
# Create schema
|
||||
with get_connection():
|
||||
pass # Schema is created by get_connection
|
||||
yield tmp_db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_entry():
|
||||
"""Create a sample MemoryEntry for testing."""
|
||||
return MemoryEntry(
|
||||
id="test-id-123",
|
||||
content="Test memory content",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
agent_id="agent-1",
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
metadata={"key": "value"},
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStoreMemory:
|
||||
"""Tests for store_memory function."""
|
||||
|
||||
def test_store_memory_basic(self, patched_db):
|
||||
"""Test storing a basic memory entry."""
|
||||
entry = store_memory(
|
||||
content="Hello world",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
)
|
||||
|
||||
assert isinstance(entry, MemoryEntry)
|
||||
assert entry.content == "Hello world"
|
||||
assert entry.source == "test"
|
||||
assert entry.context_type == "conversation"
|
||||
assert entry.id is not None
|
||||
|
||||
def test_store_memory_with_all_fields(self, patched_db):
|
||||
"""Test storing memory with all optional fields."""
|
||||
entry = store_memory(
|
||||
content="Full test content",
|
||||
source="user",
|
||||
context_type="fact",
|
||||
agent_id="agent-42",
|
||||
task_id="task-99",
|
||||
session_id="session-xyz",
|
||||
metadata={"priority": "high", "tags": ["test"]},
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
assert entry.content == "Full test content"
|
||||
assert entry.agent_id == "agent-42"
|
||||
assert entry.task_id == "task-99"
|
||||
assert entry.session_id == "session-xyz"
|
||||
assert entry.metadata == {"priority": "high", "tags": ["test"]}
|
||||
assert entry.embedding is None
|
||||
|
||||
def test_store_memory_with_embedding(self, patched_db):
|
||||
"""Test storing memory with embedding computation."""
|
||||
# TIMMY_SKIP_EMBEDDINGS=1 in conftest, so uses hash fallback
|
||||
entry = store_memory(
|
||||
content="Test content for embedding",
|
||||
source="system",
|
||||
compute_embedding=True,
|
||||
)
|
||||
|
||||
assert entry.embedding is not None
|
||||
assert isinstance(entry.embedding, list)
|
||||
assert len(entry.embedding) == 128 # Hash embedding dimension
|
||||
|
||||
def test_store_memory_without_embedding(self, patched_db):
|
||||
"""Test storing memory without embedding."""
|
||||
entry = store_memory(
|
||||
content="No embedding needed",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
assert entry.embedding is None
|
||||
|
||||
def test_store_memory_persists_to_db(self, patched_db):
|
||||
"""Test that stored memory is actually written to database."""
|
||||
entry = store_memory(
|
||||
content="Persisted content",
|
||||
source="db_test",
|
||||
context_type="document",
|
||||
)
|
||||
|
||||
# Verify directly in DB
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM memories WHERE id = ?",
|
||||
(entry.id,),
|
||||
).fetchone()
|
||||
|
||||
assert row is not None
|
||||
assert row["content"] == "Persisted content"
|
||||
assert row["memory_type"] == "document"
|
||||
assert row["source"] == "db_test"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildSearchFilters:
|
||||
"""Tests for _build_search_filters helper function."""
|
||||
|
||||
def test_no_filters(self):
|
||||
"""Test building filters with no criteria."""
|
||||
where_clause, params = _build_search_filters(None, None, None)
|
||||
assert where_clause == ""
|
||||
assert params == []
|
||||
|
||||
def test_context_type_filter(self):
|
||||
"""Test filter by context_type only."""
|
||||
where_clause, params = _build_search_filters("fact", None, None)
|
||||
assert where_clause == "WHERE memory_type = ?"
|
||||
assert params == ["fact"]
|
||||
|
||||
def test_agent_id_filter(self):
|
||||
"""Test filter by agent_id only."""
|
||||
where_clause, params = _build_search_filters(None, "agent-1", None)
|
||||
assert where_clause == "WHERE agent_id = ?"
|
||||
assert params == ["agent-1"]
|
||||
|
||||
def test_session_id_filter(self):
|
||||
"""Test filter by session_id only."""
|
||||
where_clause, params = _build_search_filters(None, None, "session-1")
|
||||
assert where_clause == "WHERE session_id = ?"
|
||||
assert params == ["session-1"]
|
||||
|
||||
def test_multiple_filters(self):
|
||||
"""Test combining multiple filters."""
|
||||
where_clause, params = _build_search_filters("conversation", "agent-1", "session-1")
|
||||
assert where_clause == "WHERE memory_type = ? AND agent_id = ? AND session_id = ?"
|
||||
assert params == ["conversation", "agent-1", "session-1"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFetchMemoryCandidates:
|
||||
"""Tests for _fetch_memory_candidates helper function."""
|
||||
|
||||
def test_fetch_with_data(self, patched_db):
|
||||
"""Test fetching candidates when data exists."""
|
||||
# Store some test data
|
||||
for i in range(5):
|
||||
store_memory(
|
||||
content=f"Test content {i}",
|
||||
source="fetch_test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
rows = _fetch_memory_candidates("", [], 10)
|
||||
assert len(rows) == 5
|
||||
|
||||
def test_fetch_with_limit(self, patched_db):
|
||||
"""Test that limit is respected."""
|
||||
for i in range(10):
|
||||
store_memory(
|
||||
content=f"Test content {i}",
|
||||
source="limit_test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
rows = _fetch_memory_candidates("", [], 3)
|
||||
assert len(rows) == 3
|
||||
|
||||
def test_fetch_with_where_clause(self, patched_db):
|
||||
"""Test fetching with WHERE clause."""
|
||||
store_memory(
|
||||
content="Fact content",
|
||||
source="test",
|
||||
context_type="fact",
|
||||
compute_embedding=False,
|
||||
)
|
||||
store_memory(
|
||||
content="Conversation content",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
where_clause, params = _build_search_filters("fact", None, None)
|
||||
rows = _fetch_memory_candidates(where_clause, params, 10)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["content"] == "Fact content"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRowToEntry:
|
||||
"""Tests for _row_to_entry conversion function."""
|
||||
|
||||
def test_convert_basic_row(self):
|
||||
"""Test converting a basic sqlite Row to MemoryEntry."""
|
||||
# Create mock row
|
||||
row_data = {
|
||||
"id": "row-1",
|
||||
"content": "Row content",
|
||||
"memory_type": "conversation",
|
||||
"source": "test",
|
||||
"agent_id": "agent-1",
|
||||
"task_id": "task-1",
|
||||
"session_id": "session-1",
|
||||
"metadata": None,
|
||||
"embedding": None,
|
||||
"created_at": "2026-03-23T10:00:00",
|
||||
}
|
||||
|
||||
# Mock sqlite3.Row behavior
|
||||
class MockRow:
|
||||
def __getitem__(self, key):
|
||||
return row_data.get(key)
|
||||
|
||||
entry = _row_to_entry(MockRow())
|
||||
assert entry.id == "row-1"
|
||||
assert entry.content == "Row content"
|
||||
assert entry.context_type == "conversation" # memory_type -> context_type
|
||||
assert entry.agent_id == "agent-1"
|
||||
|
||||
def test_convert_with_metadata(self):
|
||||
"""Test converting row with JSON metadata."""
|
||||
row_data = {
|
||||
"id": "row-2",
|
||||
"content": "Content with metadata",
|
||||
"memory_type": "fact",
|
||||
"source": "test",
|
||||
"agent_id": None,
|
||||
"task_id": None,
|
||||
"session_id": None,
|
||||
"metadata": '{"key": "value", "num": 42}',
|
||||
"embedding": None,
|
||||
"created_at": "2026-03-23T10:00:00",
|
||||
}
|
||||
|
||||
class MockRow:
|
||||
def __getitem__(self, key):
|
||||
return row_data.get(key)
|
||||
|
||||
entry = _row_to_entry(MockRow())
|
||||
assert entry.metadata == {"key": "value", "num": 42}
|
||||
|
||||
def test_convert_with_embedding(self):
|
||||
"""Test converting row with JSON embedding."""
|
||||
row_data = {
|
||||
"id": "row-3",
|
||||
"content": "Content with embedding",
|
||||
"memory_type": "conversation",
|
||||
"source": "test",
|
||||
"agent_id": None,
|
||||
"task_id": None,
|
||||
"session_id": None,
|
||||
"metadata": None,
|
||||
"embedding": "[0.1, 0.2, 0.3]",
|
||||
"created_at": "2026-03-23T10:00:00",
|
||||
}
|
||||
|
||||
class MockRow:
|
||||
def __getitem__(self, key):
|
||||
return row_data.get(key)
|
||||
|
||||
entry = _row_to_entry(MockRow())
|
||||
assert entry.embedding == [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestScoreAndFilter:
|
||||
"""Tests for _score_and_filter function."""
|
||||
|
||||
def test_empty_rows(self):
|
||||
"""Test filtering empty rows list."""
|
||||
results = _score_and_filter([], "query", [0.1, 0.2], 0.5)
|
||||
assert results == []
|
||||
|
||||
def test_filter_by_min_relevance(self, patched_db):
|
||||
"""Test filtering by minimum relevance score."""
|
||||
# Create rows with embeddings
|
||||
rows = []
|
||||
for i in range(3):
|
||||
entry = store_memory(
|
||||
content=f"Content {i}",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
compute_embedding=True, # Get actual embeddings
|
||||
)
|
||||
# Fetch row back
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM memories WHERE id = ?",
|
||||
(entry.id,),
|
||||
).fetchone()
|
||||
rows.append(row)
|
||||
|
||||
query_embedding = [0.1] * 128
|
||||
results = _score_and_filter(rows, "query", query_embedding, 0.99) # High threshold
|
||||
# Should filter out all results with high threshold
|
||||
assert len(results) <= len(rows)
|
||||
|
||||
def test_keyword_fallback_no_embedding(self, patched_db):
|
||||
"""Test keyword overlap fallback when row has no embedding."""
|
||||
# Store without embedding
|
||||
entry = store_memory(
|
||||
content="hello world test",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM memories WHERE id = ?",
|
||||
(entry.id,),
|
||||
).fetchone()
|
||||
|
||||
results = _score_and_filter([row], "hello world", [0.1] * 128, 0.0)
|
||||
assert len(results) == 1
|
||||
assert results[0].relevance_score is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSearchMemories:
|
||||
"""Tests for search_memories function."""
|
||||
|
||||
def test_search_empty_db(self, patched_db):
|
||||
"""Test searching empty database."""
|
||||
results = search_memories("anything")
|
||||
assert results == []
|
||||
|
||||
def test_search_returns_results(self, patched_db):
|
||||
"""Test search returns matching results."""
|
||||
store_memory(
|
||||
content="Python programming is fun",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
compute_embedding=True,
|
||||
)
|
||||
store_memory(
|
||||
content="JavaScript is different",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
compute_embedding=True,
|
||||
)
|
||||
|
||||
results = search_memories("python programming", limit=5)
|
||||
assert len(results) > 0
|
||||
|
||||
def test_search_with_filters(self, patched_db):
|
||||
"""Test search with context_type filter."""
|
||||
store_memory(
|
||||
content="Fact about Python",
|
||||
source="test",
|
||||
context_type="fact",
|
||||
compute_embedding=False,
|
||||
)
|
||||
store_memory(
|
||||
content="Conversation about Python",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
results = search_memories("python", context_type="fact")
|
||||
assert len(results) == 1
|
||||
assert results[0].context_type == "fact"
|
||||
|
||||
def test_search_with_agent_filter(self, patched_db):
|
||||
"""Test search with agent_id filter."""
|
||||
store_memory(
|
||||
content="Agent 1 memory",
|
||||
source="test",
|
||||
agent_id="agent-1",
|
||||
compute_embedding=False,
|
||||
)
|
||||
store_memory(
|
||||
content="Agent 2 memory",
|
||||
source="test",
|
||||
agent_id="agent-2",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
results = search_memories("memory", agent_id="agent-1")
|
||||
assert len(results) == 1
|
||||
assert results[0].agent_id == "agent-1"
|
||||
|
||||
def test_search_respects_limit(self, patched_db):
|
||||
"""Test that limit parameter is respected."""
|
||||
for i in range(10):
|
||||
store_memory(
|
||||
content=f"Memory {i}",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
results = search_memories("memory", limit=3)
|
||||
assert len(results) <= 3
|
||||
|
||||
def test_search_with_min_relevance(self, patched_db):
|
||||
"""Test search with min_relevance threshold."""
|
||||
for i in range(5):
|
||||
store_memory(
|
||||
content=f"Unique content xyz{i}",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
# High threshold should return fewer results
|
||||
results = search_memories("xyz", min_relevance=0.9, limit=10)
|
||||
# With hash embeddings, high threshold may filter everything
|
||||
assert isinstance(results, list)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeleteMemory:
|
||||
"""Tests for delete_memory function."""
|
||||
|
||||
def test_delete_existing_memory(self, patched_db):
|
||||
"""Test deleting an existing memory."""
|
||||
entry = store_memory(
|
||||
content="To be deleted",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
result = delete_memory(entry.id)
|
||||
assert result is True
|
||||
|
||||
# Verify deletion
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM memories WHERE id = ?",
|
||||
(entry.id,),
|
||||
).fetchone()
|
||||
assert row is None
|
||||
|
||||
def test_delete_nonexistent_memory(self, patched_db):
|
||||
"""Test deleting a non-existent memory."""
|
||||
result = delete_memory("nonexistent-id-12345")
|
||||
assert result is False
|
||||
|
||||
def test_delete_multiple_memories(self, patched_db):
|
||||
"""Test deleting multiple memories one by one."""
|
||||
entries = []
|
||||
for i in range(3):
|
||||
entry = store_memory(
|
||||
content=f"Delete me {i}",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
entries.append(entry)
|
||||
|
||||
for entry in entries:
|
||||
result = delete_memory(entry.id)
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetMemoryStats:
|
||||
"""Tests for get_memory_stats function."""
|
||||
|
||||
def test_stats_empty_db(self, patched_db):
|
||||
"""Test stats on empty database."""
|
||||
stats = get_memory_stats()
|
||||
assert stats["total_entries"] == 0
|
||||
assert stats["by_type"] == {}
|
||||
assert stats["with_embeddings"] == 0
|
||||
|
||||
def test_stats_with_entries(self, patched_db):
|
||||
"""Test stats with various entries."""
|
||||
store_memory(
|
||||
content="Fact 1",
|
||||
source="test",
|
||||
context_type="fact",
|
||||
compute_embedding=False,
|
||||
)
|
||||
store_memory(
|
||||
content="Fact 2",
|
||||
source="test",
|
||||
context_type="fact",
|
||||
compute_embedding=True,
|
||||
)
|
||||
store_memory(
|
||||
content="Conversation 1",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
stats = get_memory_stats()
|
||||
assert stats["total_entries"] == 3
|
||||
assert stats["by_type"]["fact"] == 2
|
||||
assert stats["by_type"]["conversation"] == 1
|
||||
assert stats["with_embeddings"] == 1
|
||||
|
||||
def test_stats_embedding_model_status(self, patched_db):
|
||||
"""Test that stats reports embedding model status."""
|
||||
stats = get_memory_stats()
|
||||
assert "has_embedding_model" in stats
|
||||
# In test mode, embeddings are skipped
|
||||
assert stats["has_embedding_model"] is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPruneMemories:
|
||||
"""Tests for prune_memories function."""
|
||||
|
||||
def test_prune_empty_db(self, patched_db):
|
||||
"""Test pruning empty database."""
|
||||
deleted = prune_memories(older_than_days=30)
|
||||
assert deleted == 0
|
||||
|
||||
def test_prune_old_memories(self, patched_db):
|
||||
"""Test pruning old memories."""
|
||||
# Store a memory
|
||||
store_memory(
|
||||
content="Recent memory",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
# Prune shouldn't delete recent memories
|
||||
deleted = prune_memories(older_than_days=30)
|
||||
assert deleted == 0
|
||||
|
||||
def test_prune_keeps_facts(self, patched_db):
|
||||
"""Test that prune keeps fact-type memories when keep_facts=True."""
|
||||
# Insert an old fact directly
|
||||
old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat()
|
||||
with get_connection() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memories
|
||||
(id, content, memory_type, source, created_at)
|
||||
VALUES (?, ?, 'fact', 'test', ?)
|
||||
""",
|
||||
("old-fact-1", "Old fact", old_time),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# Prune with keep_facts=True should not delete facts
|
||||
deleted = prune_memories(older_than_days=30, keep_facts=True)
|
||||
assert deleted == 0
|
||||
|
||||
# Verify fact still exists
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM memories WHERE id = ?",
|
||||
("old-fact-1",),
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
|
||||
def test_prune_deletes_facts_when_keep_false(self, patched_db):
|
||||
"""Test that prune deletes facts when keep_facts=False."""
|
||||
# Insert an old fact directly
|
||||
old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat()
|
||||
with get_connection() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memories
|
||||
(id, content, memory_type, source, created_at)
|
||||
VALUES (?, ?, 'fact', 'test', ?)
|
||||
""",
|
||||
("old-fact-2", "Old fact to delete", old_time),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# Prune with keep_facts=False should delete facts
|
||||
deleted = prune_memories(older_than_days=30, keep_facts=False)
|
||||
assert deleted == 1
|
||||
|
||||
def test_prune_non_fact_memories(self, patched_db):
|
||||
"""Test pruning non-fact memories."""
|
||||
# Insert old non-fact memory
|
||||
old_time = (datetime.now(UTC) - timedelta(days=100)).isoformat()
|
||||
with get_connection() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memories
|
||||
(id, content, memory_type, source, created_at)
|
||||
VALUES (?, ?, 'conversation', 'test', ?)
|
||||
""",
|
||||
("old-conv-1", "Old conversation", old_time),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
deleted = prune_memories(older_than_days=30, keep_facts=True)
|
||||
assert deleted == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetMemoryContext:
|
||||
"""Tests for get_memory_context function."""
|
||||
|
||||
def test_empty_context(self, patched_db):
|
||||
"""Test getting context from empty database."""
|
||||
context = get_memory_context("query")
|
||||
assert context == ""
|
||||
|
||||
def test_context_with_results(self, patched_db):
|
||||
"""Test getting context with matching results."""
|
||||
store_memory(
|
||||
content="Python is a programming language",
|
||||
source="user",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
context = get_memory_context("python programming")
|
||||
# May or may not match depending on search
|
||||
assert isinstance(context, str)
|
||||
|
||||
def test_context_respects_max_tokens(self, patched_db):
|
||||
"""Test that max_tokens limits context size."""
|
||||
for i in range(20):
|
||||
store_memory(
|
||||
content=f"This is memory number {i} with some content",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
context = get_memory_context("memory", max_tokens=100)
|
||||
# Rough approximation: 100 tokens * 4 chars = 400 chars
|
||||
assert len(context) <= 500 or context == ""
|
||||
|
||||
def test_context_formatting(self, patched_db):
|
||||
"""Test that context is properly formatted."""
|
||||
store_memory(
|
||||
content="Important information",
|
||||
source="system",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
context = get_memory_context("important")
|
||||
if context:
|
||||
assert "Relevant context from memory:" in context or context == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPersonalFacts:
|
||||
"""Tests for personal facts functions."""
|
||||
|
||||
def test_recall_personal_facts_empty(self, patched_db):
|
||||
"""Test recalling facts when none exist."""
|
||||
facts = recall_personal_facts()
|
||||
assert facts == []
|
||||
|
||||
def test_store_and_recall_personal_fact(self, patched_db):
|
||||
"""Test storing and recalling a personal fact."""
|
||||
entry = store_personal_fact("User likes Python", agent_id="agent-1")
|
||||
|
||||
assert entry.context_type == "fact"
|
||||
assert entry.content == "User likes Python"
|
||||
assert entry.agent_id == "agent-1"
|
||||
|
||||
facts = recall_personal_facts()
|
||||
assert "User likes Python" in facts
|
||||
|
||||
def test_recall_personal_facts_with_agent_filter(self, patched_db):
|
||||
"""Test recalling facts filtered by agent_id."""
|
||||
store_personal_fact("Fact for agent 1", agent_id="agent-1")
|
||||
store_personal_fact("Fact for agent 2", agent_id="agent-2")
|
||||
|
||||
facts = recall_personal_facts(agent_id="agent-1")
|
||||
assert len(facts) == 1
|
||||
assert "Fact for agent 1" in facts
|
||||
|
||||
def test_recall_personal_facts_with_ids(self, patched_db):
|
||||
"""Test recalling facts with their IDs."""
|
||||
entry = store_personal_fact("Fact with ID", agent_id="agent-1")
|
||||
|
||||
facts_with_ids = recall_personal_facts_with_ids()
|
||||
assert len(facts_with_ids) == 1
|
||||
assert facts_with_ids[0]["id"] == entry.id
|
||||
assert facts_with_ids[0]["content"] == "Fact with ID"
|
||||
|
||||
def test_update_personal_fact(self, patched_db):
|
||||
"""Test updating a personal fact."""
|
||||
entry = store_personal_fact("Original fact", agent_id="agent-1")
|
||||
|
||||
result = update_personal_fact(entry.id, "Updated fact")
|
||||
assert result is True
|
||||
|
||||
facts = recall_personal_facts()
|
||||
assert "Updated fact" in facts
|
||||
assert "Original fact" not in facts
|
||||
|
||||
def test_update_nonexistent_fact(self, patched_db):
|
||||
"""Test updating a non-existent fact."""
|
||||
result = update_personal_fact("nonexistent-id", "New content")
|
||||
assert result is False
|
||||
|
||||
def test_update_only_affects_facts(self, patched_db):
|
||||
"""Test that update only affects fact-type memories."""
|
||||
# Store a non-fact memory
|
||||
entry = store_memory(
|
||||
content="Not a fact",
|
||||
source="test",
|
||||
context_type="conversation",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
# Try to update it as if it were a fact
|
||||
result = update_personal_fact(entry.id, "Updated content")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReflections:
|
||||
"""Tests for reflection storage and recall."""
|
||||
|
||||
def test_store_and_recall_reflection(self, patched_db):
|
||||
"""Test storing and recalling a reflection."""
|
||||
store_last_reflection("This is my reflection")
|
||||
|
||||
result = recall_last_reflection()
|
||||
assert result == "This is my reflection"
|
||||
|
||||
def test_reflection_replaces_previous(self, patched_db):
|
||||
"""Test that storing reflection replaces the previous one."""
|
||||
store_last_reflection("First reflection")
|
||||
store_last_reflection("Second reflection")
|
||||
|
||||
result = recall_last_reflection()
|
||||
assert result == "Second reflection"
|
||||
|
||||
# Verify only one reflection in DB
|
||||
with get_connection() as conn:
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM memories WHERE memory_type = 'reflection'"
|
||||
).fetchone()[0]
|
||||
assert count == 1
|
||||
|
||||
def test_store_empty_reflection(self, patched_db):
|
||||
"""Test that empty reflection is not stored."""
|
||||
store_last_reflection("")
|
||||
store_last_reflection(" ")
|
||||
store_last_reflection(None)
|
||||
|
||||
result = recall_last_reflection()
|
||||
assert result is None
|
||||
|
||||
def test_recall_no_reflection(self, patched_db):
|
||||
"""Test recalling when no reflection exists."""
|
||||
result = recall_last_reflection()
|
||||
assert result is None
|
||||
|
||||
def test_reflection_strips_whitespace(self, patched_db):
|
||||
"""Test that reflection content is stripped."""
|
||||
store_last_reflection(" Reflection with whitespace ")
|
||||
|
||||
result = recall_last_reflection()
|
||||
assert result == "Reflection with whitespace"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEdgeCases:
|
||||
"""Edge cases and error handling."""
|
||||
|
||||
def test_unicode_content(self, patched_db):
|
||||
"""Test handling of unicode content."""
|
||||
entry = store_memory(
|
||||
content="Unicode: 你好世界 🎉 café naïve",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
assert entry.content == "Unicode: 你好世界 🎉 café naïve"
|
||||
|
||||
# Verify in DB
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT content FROM memories WHERE id = ?",
|
||||
(entry.id,),
|
||||
).fetchone()
|
||||
assert "你好世界" in row["content"]
|
||||
|
||||
def test_special_characters_in_content(self, patched_db):
|
||||
"""Test handling of special characters."""
|
||||
content = """<script>alert('xss')</script>
|
||||
SQL: SELECT * FROM users
|
||||
JSON: {"key": "value"}
|
||||
Escapes: \\n \\t"""
|
||||
|
||||
entry = store_memory(
|
||||
content=content,
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
assert entry.content == content
|
||||
|
||||
def test_very_long_content(self, patched_db):
|
||||
"""Test handling of very long content."""
|
||||
long_content = "Word " * 1000
|
||||
|
||||
entry = store_memory(
|
||||
content=long_content,
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
assert len(entry.content) == len(long_content)
|
||||
|
||||
def test_metadata_with_nested_structure(self, patched_db):
|
||||
"""Test storing metadata with nested structure."""
|
||||
metadata = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": ["item1", "item2"]
|
||||
}
|
||||
},
|
||||
"number": 42,
|
||||
"boolean": True,
|
||||
"null": None,
|
||||
}
|
||||
|
||||
entry = store_memory(
|
||||
content="Nested metadata test",
|
||||
source="test",
|
||||
metadata=metadata,
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
# Verify metadata round-trips correctly
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT metadata FROM memories WHERE id = ?",
|
||||
(entry.id,),
|
||||
).fetchone()
|
||||
|
||||
loaded = json.loads(row["metadata"])
|
||||
assert loaded["level1"]["level2"]["level3"] == ["item1", "item2"]
|
||||
assert loaded["number"] == 42
|
||||
assert loaded["boolean"] is True
|
||||
|
||||
def test_duplicate_keys_not_prevented(self, patched_db):
|
||||
"""Test that duplicate content is allowed."""
|
||||
entry1 = store_memory(
|
||||
content="Duplicate content",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
entry2 = store_memory(
|
||||
content="Duplicate content",
|
||||
source="test",
|
||||
compute_embedding=False,
|
||||
)
|
||||
|
||||
assert entry1.id != entry2.id
|
||||
|
||||
stats = get_memory_stats()
|
||||
assert stats["total_entries"] == 2
|
||||
Reference in New Issue
Block a user