"""Comprehensive tests for spark.memory module.

Covers:
- SparkEvent / SparkMemory dataclasses
- _get_conn (schema creation, WAL, busy timeout, idempotent indexes)
- score_importance (all event types, boosts, edge cases)
- record_event (auto-importance, explicit importance, invalid JSON)
- get_events (all filters, ordering, limit)
- count_events (total, by type)
- store_memory (with/without expiry)
- get_memories (all filters)
- count_memories (total, by type)

Not yet covered: the record_event swarm bridge (TODO: add a test for it).
"""

import json

import pytest

from spark.memory import (
    IMPORTANCE_HIGH,
    IMPORTANCE_LOW,
    IMPORTANCE_MEDIUM,
    SparkEvent,
    SparkMemory,
    _get_conn,
    count_events,
    count_memories,
    get_events,
    get_memories,
    record_event,
    score_importance,
    store_memory,
)

# ── Constants ─────────────────────────────────────────────────────────────


class TestConstants:
    """Sanity checks for the importance-tier constants."""

    def test_importance_ordering(self):
        assert IMPORTANCE_LOW < IMPORTANCE_MEDIUM < IMPORTANCE_HIGH


# ── Dataclasses ───────────────────────────────────────────────────────────


class TestSparkEventDataclass:
    """SparkEvent stores the values it is constructed with."""

    def test_all_fields(self):
        ev = SparkEvent(
            id="1",
            event_type="task_posted",
            agent_id="a1",
            task_id="t1",
            description="Test",
            data="{}",
            importance=0.5,
            created_at="2026-01-01",
        )
        # Check every field, not just a sample, so the test matches its name.
        assert ev.id == "1"
        assert ev.event_type == "task_posted"
        assert ev.agent_id == "a1"
        assert ev.task_id == "t1"
        assert ev.description == "Test"
        assert ev.data == "{}"
        assert ev.importance == 0.5
        assert ev.created_at == "2026-01-01"

    def test_nullable_fields(self):
        ev = SparkEvent(
            id="2",
            event_type="task_posted",
            agent_id=None,
            task_id=None,
            description="",
            data="{}",
            importance=0.5,
            created_at="2026-01-01",
        )
        assert ev.agent_id is None
        assert ev.task_id is None


class TestSparkMemoryDataclass:
    """SparkMemory stores the values it is constructed with."""

    def test_all_fields(self):
        mem = SparkMemory(
            id="1",
            memory_type="pattern",
            subject="system",
            content="Test insight",
            confidence=0.8,
            source_events=5,
            created_at="2026-01-01",
            expires_at="2026-12-31",
        )
        # Check every field, not just a sample, so the test matches its name.
        assert mem.id == "1"
        assert mem.memory_type == "pattern"
        assert mem.subject == "system"
        assert mem.content == "Test insight"
        assert mem.confidence == 0.8
        assert mem.source_events == 5
        assert mem.created_at == "2026-01-01"
        assert mem.expires_at == "2026-12-31"

    def test_nullable_expires(self):
        mem = SparkMemory(
            id="2",
            memory_type="anomaly",
            subject="agent-1",
            content="Odd behavior",
            confidence=0.6,
            source_events=3,
            created_at="2026-01-01",
            expires_at=None,
        )
        assert mem.expires_at is None


# ── _get_conn ─────────────────────────────────────────────────────────────


class TestGetConn:
    """Connection setup: schema, journal mode, busy timeout, re-entry."""

    def test_creates_tables(self):
        with _get_conn() as conn:
            tables = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            ).fetchall()
            # Rows are indexed by column name, so this assumes the connection
            # uses a name-addressable row factory (e.g. sqlite3.Row).
            names = {r["name"] for r in tables}
            assert "spark_events" in names
            assert "spark_memories" in names

    def test_wal_mode(self):
        with _get_conn() as conn:
            # SQLite reports the journal mode in lowercase.
            mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
            assert mode == "wal"

    def test_busy_timeout(self):
        with _get_conn() as conn:
            # PRAGMA busy_timeout is reported in milliseconds.
            timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
            assert timeout == 5000

    def test_idempotent(self):
        # Calling _get_conn twice should not raise (schema/index creation
        # must tolerate already-existing objects).
        with _get_conn():
            pass
        with _get_conn():
            pass


# ── score_importance ──────────────────────────────────────────────────────


class TestScoreImportance:
    """Base scores per event type plus data-driven boosts and capping."""

    @pytest.mark.parametrize(
        "event_type,expected_min,expected_max",
        [
            ("task_posted", 0.3, 0.5),
            ("bid_submitted", 0.1, 0.3),
            ("task_assigned", 0.4, 0.6),
            ("task_completed", 0.5, 0.7),
            ("task_failed", 0.9, 1.0),
            ("agent_joined", 0.4, 0.6),
            ("prediction_result", 0.6, 0.8),
        ],
    )
    def test_base_scores(self, event_type, expected_min, expected_max):
        # Empty data => no boosts apply, so this is the pure base score.
        score = score_importance(event_type, {})
        assert expected_min <= score <= expected_max

    def test_unknown_event_default(self):
        assert score_importance("never_heard_of_this", {}) == 0.5

    def test_failure_boost(self):
        score = score_importance("task_failed", {})
        assert score == 1.0

    def test_high_bid_boost(self):
        low = score_importance("bid_submitted", {"bid_sats": 10})
        high = score_importance("bid_submitted", {"bid_sats": 100})
        assert high > low
        assert high <= 1.0

    def test_high_bid_on_failure(self):
        score = score_importance("task_failed", {"bid_sats": 100})
        assert score == 1.0  # capped at 1.0

    def test_score_always_rounded(self):
        score = score_importance("bid_submitted", {"bid_sats": 100})
        assert score == round(score, 2)
# ── record_event ──────────────────────────────────────────────────────────


class TestRecordEvent:
    """record_event persists events and derives importance when not given."""

    def test_basic_record(self):
        eid = record_event("task_posted", "New task", task_id="t1")
        assert isinstance(eid, str)
        assert len(eid) > 0

    def test_auto_importance(self):
        record_event("task_failed", "Failed", task_id="t-auto")
        events = get_events(task_id="t-auto")
        assert len(events) == 1
        assert events[0].importance >= 0.9

    def test_explicit_importance(self):
        record_event("task_posted", "Custom", task_id="t-expl", importance=0.1)
        events = get_events(task_id="t-expl")
        assert len(events) == 1
        assert events[0].importance == 0.1

    def test_with_agent_and_data(self):
        data = json.dumps({"bid_sats": 42})
        record_event("bid_submitted", "Bid", agent_id="a1", task_id="t-data", data=data)
        events = get_events(task_id="t-data")
        assert len(events) == 1
        assert events[0].agent_id == "a1"
        parsed = json.loads(events[0].data)
        assert parsed["bid_sats"] == 42

    def test_invalid_json_data_uses_default_importance(self):
        # Unparseable data must not raise; scoring falls back to the
        # event type's base score with no data-driven boosts.
        record_event("task_posted", "Bad data", task_id="t-bad", data="not-json")
        events = get_events(task_id="t-bad")
        assert len(events) == 1
        assert events[0].importance == 0.4  # base for task_posted

    def test_returns_unique_ids(self):
        id1 = record_event("task_posted", "A")
        id2 = record_event("task_posted", "B")
        assert id1 != id2


# ── get_events ────────────────────────────────────────────────────────────


class TestGetEvents:
    """Filtering, ordering and limiting of stored events."""

    def test_empty_db(self):
        assert get_events() == []

    def test_filter_by_type(self):
        record_event("task_posted", "A")
        record_event("task_completed", "B")
        events = get_events(event_type="task_posted")
        assert len(events) == 1
        assert events[0].event_type == "task_posted"

    def test_filter_by_agent(self):
        record_event("task_posted", "A", agent_id="a1")
        record_event("task_posted", "B", agent_id="a2")
        events = get_events(agent_id="a1")
        assert len(events) == 1
        assert events[0].agent_id == "a1"

    def test_filter_by_task(self):
        record_event("task_posted", "A", task_id="t1")
        record_event("task_posted", "B", task_id="t2")
        events = get_events(task_id="t1")
        assert len(events) == 1
        # A count alone would also pass if the wrong row matched.
        assert events[0].task_id == "t1"

    def test_filter_by_min_importance(self):
        record_event("task_posted", "Low", importance=0.1)
        record_event("task_failed", "High", importance=0.9)
        events = get_events(min_importance=0.5)
        assert len(events) == 1
        assert events[0].importance >= 0.5
        assert events[0].description == "High"

    def test_limit(self):
        for i in range(10):
            record_event("task_posted", f"ev{i}")
        events = get_events(limit=3)
        assert len(events) == 3

    def test_order_by_created_desc(self):
        # NOTE(review): relies on the two inserts getting distinct created_at
        # values; if timestamp resolution is coarse they could tie and the
        # order would be unspecified — confirm against record_event.
        record_event("task_posted", "first", task_id="ord1")
        record_event("task_posted", "second", task_id="ord2")
        events = get_events()
        # Most recent first
        assert [e.task_id for e in events] == ["ord2", "ord1"]

    def test_combined_filters(self):
        record_event("task_failed", "A", agent_id="a1", task_id="t1", importance=0.9)
        record_event("task_posted", "B", agent_id="a1", task_id="t2", importance=0.4)
        record_event("task_failed", "C", agent_id="a2", task_id="t3", importance=0.9)
        events = get_events(event_type="task_failed", agent_id="a1", min_importance=0.5)
        assert len(events) == 1
        assert events[0].task_id == "t1"


# ── count_events ──────────────────────────────────────────────────────────


class TestCountEvents:
    """Total and per-type event counts."""

    def test_empty(self):
        assert count_events() == 0

    def test_total(self):
        record_event("task_posted", "A")
        record_event("task_failed", "B")
        assert count_events() == 2

    def test_by_type(self):
        record_event("task_posted", "A")
        record_event("task_posted", "B")
        record_event("task_failed", "C")
        assert count_events("task_posted") == 2
        assert count_events("task_failed") == 1
        assert count_events("task_completed") == 0


# ── store_memory ──────────────────────────────────────────────────────────


class TestStoreMemory:
    """store_memory persists memories with explicit or default attributes."""

    def test_basic_store(self):
        mid = store_memory("pattern", "system", "Test insight")
        assert isinstance(mid, str)
        assert len(mid) > 0

    def test_returns_unique_ids(self):
        id1 = store_memory("pattern", "a", "X")
        id2 = store_memory("pattern", "b", "Y")
        assert id1 != id2

    def test_with_all_params(self):
        store_memory(
            "anomaly",
            "agent-1",
            "Odd pattern",
            confidence=0.9,
            source_events=10,
            expires_at="2026-12-31",
        )
        mems = get_memories(subject="agent-1")
        assert len(mems) == 1
        assert mems[0].confidence == 0.9
        assert mems[0].source_events == 10
        assert mems[0].expires_at == "2026-12-31"

    def test_default_values(self):
        store_memory("insight", "sys", "Default test")
        mems = get_memories(subject="sys")
        assert len(mems) == 1
        assert mems[0].confidence == 0.5
        assert mems[0].source_events == 0
        assert mems[0].expires_at is None


# ── get_memories ──────────────────────────────────────────────────────────


class TestGetMemories:
    """Filtering and limiting of stored memories."""

    def test_empty(self):
        assert get_memories() == []

    def test_filter_by_type(self):
        store_memory("pattern", "a", "X")
        store_memory("anomaly", "a", "Y")
        mems = get_memories(memory_type="pattern")
        assert len(mems) == 1
        assert mems[0].memory_type == "pattern"

    def test_filter_by_subject(self):
        store_memory("pattern", "a", "X")
        store_memory("pattern", "b", "Y")
        mems = get_memories(subject="a")
        assert len(mems) == 1
        # A count alone would also pass if the wrong row matched.
        assert mems[0].subject == "a"

    def test_filter_by_min_confidence(self):
        store_memory("pattern", "a", "Low", confidence=0.2)
        store_memory("pattern", "b", "High", confidence=0.9)
        mems = get_memories(min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "High"

    def test_limit(self):
        for i in range(10):
            store_memory("pattern", "a", f"M{i}")
        mems = get_memories(limit=3)
        assert len(mems) == 3

    def test_combined_filters(self):
        store_memory("pattern", "a", "Target", confidence=0.9)
        store_memory("anomaly", "a", "Wrong type", confidence=0.9)
        store_memory("pattern", "b", "Wrong subject", confidence=0.9)
        store_memory("pattern", "a", "Low conf", confidence=0.1)
        mems = get_memories(memory_type="pattern", subject="a", min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "Target"


# ── count_memories ────────────────────────────────────────────────────────


class TestCountMemories:
    """Total and per-type memory counts."""

    def test_empty(self):
        assert count_memories() == 0

    def test_total(self):
        store_memory("pattern", "a", "X")
        store_memory("anomaly", "b", "Y")
        assert count_memories() == 2

    def test_by_type(self):
        store_memory("pattern", "a", "X")
        store_memory("pattern", "b", "Y")
        store_memory("anomaly", "c", "Z")
        assert count_memories("pattern") == 2
        assert count_memories("anomaly") == 1
        assert count_memories("insight") == 0