390 lines
13 KiB
Python
390 lines
13 KiB
Python
"""Comprehensive tests for spark.memory module.

Covers:
- SparkEvent / SparkMemory dataclasses
- _get_conn (schema creation, WAL, busy timeout, idempotent indexes)
- score_importance (all event types, boosts, edge cases)
- record_event (auto-importance, explicit importance, invalid JSON, swarm bridge)
- get_events (all filters, ordering, limit)
- count_events (total, by type)
- store_memory (with/without expiry)
- get_memories (all filters)
- count_memories (total, by type)
"""
|
|
|
|
import json
|
|
|
|
import pytest
|
|
|
|
from spark.memory import (
|
|
IMPORTANCE_HIGH,
|
|
IMPORTANCE_LOW,
|
|
IMPORTANCE_MEDIUM,
|
|
SparkEvent,
|
|
SparkMemory,
|
|
_get_conn,
|
|
count_events,
|
|
count_memories,
|
|
get_events,
|
|
get_memories,
|
|
record_event,
|
|
score_importance,
|
|
store_memory,
|
|
)
|
|
|
|
# ── Constants ─────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestConstants:
    """Sanity checks on the IMPORTANCE_* threshold constants."""

    def test_importance_ordering(self):
        """LOW, MEDIUM and HIGH must be strictly increasing."""
        assert IMPORTANCE_LOW < IMPORTANCE_MEDIUM
        assert IMPORTANCE_MEDIUM < IMPORTANCE_HIGH
|
|
|
|
|
|
# ── Dataclasses ───────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestSparkEventDataclass:
    """Construction checks for SparkEvent."""

    def test_all_fields(self):
        """An event built with every field set exposes its type and agent id."""
        fields = {
            "id": "1",
            "event_type": "task_posted",
            "agent_id": "a1",
            "task_id": "t1",
            "description": "Test",
            "data": "{}",
            "importance": 0.5,
            "created_at": "2026-01-01",
        }
        event = SparkEvent(**fields)
        assert event.event_type == "task_posted"
        assert event.agent_id == "a1"

    def test_nullable_fields(self):
        """agent_id and task_id may be None and are stored as None."""
        fields = {
            "id": "2",
            "event_type": "task_posted",
            "agent_id": None,
            "task_id": None,
            "description": "",
            "data": "{}",
            "importance": 0.5,
            "created_at": "2026-01-01",
        }
        event = SparkEvent(**fields)
        assert event.agent_id is None
        assert event.task_id is None
|
|
|
|
|
|
class TestSparkMemoryDataclass:
    """Construction checks for SparkMemory."""

    def test_all_fields(self):
        """A memory built with every field set exposes its type and expiry."""
        fields = {
            "id": "1",
            "memory_type": "pattern",
            "subject": "system",
            "content": "Test insight",
            "confidence": 0.8,
            "source_events": 5,
            "created_at": "2026-01-01",
            "expires_at": "2026-12-31",
        }
        memory = SparkMemory(**fields)
        assert memory.memory_type == "pattern"
        assert memory.expires_at == "2026-12-31"

    def test_nullable_expires(self):
        """expires_at may be None and is stored as None."""
        fields = {
            "id": "2",
            "memory_type": "anomaly",
            "subject": "agent-1",
            "content": "Odd behavior",
            "confidence": 0.6,
            "source_events": 3,
            "created_at": "2026-01-01",
            "expires_at": None,
        }
        memory = SparkMemory(**fields)
        assert memory.expires_at is None
|
|
|
|
|
|
# ── _get_conn ─────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestGetConn:
    """Checks on the connection handed out by _get_conn()."""

    def test_creates_tables(self):
        """Both spark tables are listed in sqlite_master."""
        with _get_conn() as conn:
            rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
            table_names = set()
            for row in rows:
                table_names.add(row["name"])
            assert "spark_events" in table_names
            assert "spark_memories" in table_names

    def test_wal_mode(self):
        """The connection reports journal_mode 'wal'."""
        with _get_conn() as conn:
            row = conn.execute("PRAGMA journal_mode").fetchone()
            assert row[0] == "wal"

    def test_busy_timeout(self):
        """The connection reports a busy_timeout of 5000."""
        with _get_conn() as conn:
            row = conn.execute("PRAGMA busy_timeout").fetchone()
            assert row[0] == 5000

    def test_idempotent(self):
        """Opening a connection twice in a row must not raise."""
        for _ in range(2):
            with _get_conn():
                pass
|
|
|
|
|
|
# ── score_importance ──────────────────────────────────────────────────────
|
|
|
|
|
|
class TestScoreImportance:
    """Tests for score_importance(event_type, data)."""

    @pytest.mark.parametrize(
        ("event_type", "expected_min", "expected_max"),
        [
            ("task_posted", 0.3, 0.5),
            ("bid_submitted", 0.1, 0.3),
            ("task_assigned", 0.4, 0.6),
            ("task_completed", 0.5, 0.7),
            ("task_failed", 0.9, 1.0),
            ("agent_joined", 0.4, 0.6),
            ("prediction_result", 0.6, 0.8),
        ],
    )
    def test_base_scores(self, event_type, expected_min, expected_max):
        """With empty data, each known event type scores inside its expected band."""
        base = score_importance(event_type, {})
        assert expected_min <= base
        assert base <= expected_max

    def test_unknown_event_default(self):
        """An unrecognised event type scores exactly 0.5."""
        unknown = score_importance("never_heard_of_this", {})
        assert unknown == 0.5

    def test_failure_boost(self):
        """A task_failed event with empty data scores exactly 1.0."""
        failed = score_importance("task_failed", {})
        assert failed == 1.0

    def test_high_bid_boost(self):
        """A 100-sat bid scores above a 10-sat bid and never exceeds 1.0."""
        small_bid = score_importance("bid_submitted", {"bid_sats": 10})
        large_bid = score_importance("bid_submitted", {"bid_sats": 100})
        assert large_bid > small_bid
        assert large_bid <= 1.0

    def test_high_bid_on_failure(self):
        """A failure carrying a large bid still scores 1.0."""
        combined = score_importance("task_failed", {"bid_sats": 100})
        assert combined == 1.0  # capped at 1.0

    def test_score_always_rounded(self):
        """The returned score is unchanged by rounding to two decimals."""
        value = score_importance("bid_submitted", {"bid_sats": 100})
        assert value == round(value, 2)
|
|
|
|
|
|
# ── record_event ──────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestRecordEvent:
    """Tests for record_event(), read back through get_events().

    Each read-back test first asserts that exactly one event was returned
    for its unique task id, so a storage or lookup regression fails with a
    clear assertion instead of an IndexError on ``events[0]``.
    """

    def test_basic_record(self):
        """record_event returns a non-empty string id."""
        eid = record_event("task_posted", "New task", task_id="t1")
        assert isinstance(eid, str)
        assert len(eid) > 0

    def test_auto_importance(self):
        """Without an explicit importance, a task_failed event scores >= 0.9."""
        record_event("task_failed", "Failed", task_id="t-auto")
        events = get_events(task_id="t-auto")
        assert len(events) == 1
        assert events[0].importance >= 0.9

    def test_explicit_importance(self):
        """An explicit importance is stored as given."""
        record_event("task_posted", "Custom", task_id="t-expl", importance=0.1)
        events = get_events(task_id="t-expl")
        assert len(events) == 1
        assert events[0].importance == 0.1

    def test_with_agent_and_data(self):
        """agent_id and the JSON data payload survive the round trip."""
        data = json.dumps({"bid_sats": 42})
        record_event("bid_submitted", "Bid", agent_id="a1", task_id="t-data", data=data)
        events = get_events(task_id="t-data")
        assert len(events) == 1
        assert events[0].agent_id == "a1"
        parsed = json.loads(events[0].data)
        assert parsed["bid_sats"] == 42

    def test_invalid_json_data_uses_default_importance(self):
        """Unparseable data must not break recording; importance is 0.4."""
        record_event("task_posted", "Bad data", task_id="t-bad", data="not-json")
        events = get_events(task_id="t-bad")
        assert len(events) == 1
        assert events[0].importance == 0.4  # base for task_posted

    def test_returns_unique_ids(self):
        """Two consecutive events receive different ids."""
        id1 = record_event("task_posted", "A")
        id2 = record_event("task_posted", "B")
        assert id1 != id2
|
|
|
|
|
|
# ── get_events ────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestGetEvents:
    """Tests for get_events() filters, limit and ordering.

    The exact-count assertions only hold if each test starts with no
    previously recorded events.
    NOTE(review): presumably a fixture outside this file provides that
    isolation — confirm.
    """

    def test_empty_db(self):
        """With nothing recorded, get_events returns an empty list."""
        assert get_events() == []

    def test_filter_by_type(self):
        """event_type selects only events of that type."""
        record_event("task_posted", "A")
        record_event("task_completed", "B")
        events = get_events(event_type="task_posted")
        assert len(events) == 1
        assert events[0].event_type == "task_posted"

    def test_filter_by_agent(self):
        """agent_id selects only events for that agent."""
        record_event("task_posted", "A", agent_id="a1")
        record_event("task_posted", "B", agent_id="a2")
        events = get_events(agent_id="a1")
        assert len(events) == 1
        assert events[0].agent_id == "a1"

    def test_filter_by_task(self):
        """task_id selects only events for that task."""
        record_event("task_posted", "A", task_id="t1")
        record_event("task_posted", "B", task_id="t2")
        events = get_events(task_id="t1")
        assert len(events) == 1
        # Check the matched row too, as the type and agent filter tests do;
        # a count alone would pass if the wrong event came back.
        assert events[0].task_id == "t1"

    def test_filter_by_min_importance(self):
        """min_importance drops events scored below the threshold."""
        record_event("task_posted", "Low", importance=0.1)
        record_event("task_failed", "High", importance=0.9)
        events = get_events(min_importance=0.5)
        assert len(events) == 1
        assert events[0].importance >= 0.5

    def test_limit(self):
        """limit caps the number of returned events."""
        for i in range(10):
            record_event("task_posted", f"ev{i}")
        events = get_events(limit=3)
        assert len(events) == 3

    def test_order_by_created_desc(self):
        """The most recently recorded event is expected first."""
        record_event("task_posted", "first", task_id="ord1")
        record_event("task_posted", "second", task_id="ord2")
        events = get_events()
        # Most recent first
        # NOTE(review): if both events can share a created_at value, this
        # ordering is not guaranteed — confirm the timestamp resolution.
        assert events[0].task_id == "ord2"

    def test_combined_filters(self):
        """event_type, agent_id and min_importance are applied together."""
        record_event("task_failed", "A", agent_id="a1", task_id="t1", importance=0.9)
        record_event("task_posted", "B", agent_id="a1", task_id="t2", importance=0.4)
        record_event("task_failed", "C", agent_id="a2", task_id="t3", importance=0.9)
        events = get_events(event_type="task_failed", agent_id="a1", min_importance=0.5)
        assert len(events) == 1
        assert events[0].task_id == "t1"
|
|
|
|
|
|
# ── count_events ──────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestCountEvents:
    """Tests for count_events(), with and without an event-type argument."""

    def test_empty(self):
        """With nothing recorded the total is zero."""
        total = count_events()
        assert total == 0

    def test_total(self):
        """Without an argument, every recorded event is counted."""
        for event_type, description in (("task_posted", "A"), ("task_failed", "B")):
            record_event(event_type, description)
        assert count_events() == 2

    def test_by_type(self):
        """Passing an event type counts only events of that type."""
        recorded = (
            ("task_posted", "A"),
            ("task_posted", "B"),
            ("task_failed", "C"),
        )
        for event_type, description in recorded:
            record_event(event_type, description)

        posted = count_events("task_posted")
        failed = count_events("task_failed")
        completed = count_events("task_completed")
        assert posted == 2
        assert failed == 1
        assert completed == 0
|
|
|
|
|
|
# ── store_memory ──────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestStoreMemory:
    """Tests for store_memory(), read back through get_memories()."""

    def test_basic_store(self):
        """store_memory returns a non-empty string id."""
        mid = store_memory("pattern", "system", "Test insight")
        assert isinstance(mid, str)
        assert len(mid) > 0

    def test_returns_unique_ids(self):
        """Two stored memories receive different ids."""
        id1 = store_memory("pattern", "a", "X")
        id2 = store_memory("pattern", "b", "Y")
        assert id1 != id2

    def test_with_all_params(self):
        """confidence, source_events and expires_at are stored as given."""
        store_memory(
            "anomaly",
            "agent-1",
            "Odd pattern",
            confidence=0.9,
            source_events=10,
            expires_at="2026-12-31",
        )
        mems = get_memories(subject="agent-1")
        assert len(mems) == 1
        assert mems[0].confidence == 0.9
        assert mems[0].source_events == 10
        assert mems[0].expires_at == "2026-12-31"

    def test_default_values(self):
        """Omitted optional arguments read back as 0.5, 0 and None."""
        store_memory("insight", "sys", "Default test")
        mems = get_memories(subject="sys")
        # Guard the index below, matching test_with_all_params: a missing row
        # should fail as an assertion rather than as an IndexError.
        assert len(mems) == 1
        assert mems[0].confidence == 0.5
        assert mems[0].source_events == 0
        assert mems[0].expires_at is None
|
|
|
|
|
|
# ── get_memories ──────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestGetMemories:
    """Tests for get_memories() filters and limit.

    The exact-count assertions only hold if each test starts with no
    previously stored memories.
    NOTE(review): presumably a fixture outside this file provides that
    isolation — confirm.
    """

    def test_empty(self):
        """With nothing stored, get_memories returns an empty list."""
        assert get_memories() == []

    def test_filter_by_type(self):
        """memory_type selects only memories of that type."""
        store_memory("pattern", "a", "X")
        store_memory("anomaly", "a", "Y")
        mems = get_memories(memory_type="pattern")
        assert len(mems) == 1
        assert mems[0].memory_type == "pattern"

    def test_filter_by_subject(self):
        """subject selects only memories about that subject."""
        store_memory("pattern", "a", "X")
        store_memory("pattern", "b", "Y")
        mems = get_memories(subject="a")
        assert len(mems) == 1
        # Check the matched row too, as the sibling filter tests do; a count
        # alone would pass if the wrong memory came back.
        assert mems[0].subject == "a"

    def test_filter_by_min_confidence(self):
        """min_confidence drops memories below the threshold."""
        store_memory("pattern", "a", "Low", confidence=0.2)
        store_memory("pattern", "b", "High", confidence=0.9)
        mems = get_memories(min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "High"

    def test_limit(self):
        """limit caps the number of returned memories."""
        for i in range(10):
            store_memory("pattern", "a", f"M{i}")
        mems = get_memories(limit=3)
        assert len(mems) == 3

    def test_combined_filters(self):
        """memory_type, subject and min_confidence are applied together."""
        store_memory("pattern", "a", "Target", confidence=0.9)
        store_memory("anomaly", "a", "Wrong type", confidence=0.9)
        store_memory("pattern", "b", "Wrong subject", confidence=0.9)
        store_memory("pattern", "a", "Low conf", confidence=0.1)
        mems = get_memories(memory_type="pattern", subject="a", min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "Target"
|
|
|
|
|
|
# ── count_memories ────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestCountMemories:
    """Tests for count_memories(), with and without a memory-type argument."""

    def test_empty(self):
        """With nothing stored the total is zero."""
        total = count_memories()
        assert total == 0

    def test_total(self):
        """Without an argument, every stored memory is counted."""
        for memory_type, subject, content in (
            ("pattern", "a", "X"),
            ("anomaly", "b", "Y"),
        ):
            store_memory(memory_type, subject, content)
        assert count_memories() == 2

    def test_by_type(self):
        """Passing a memory type counts only memories of that type."""
        stored = (
            ("pattern", "a", "X"),
            ("pattern", "b", "Y"),
            ("anomaly", "c", "Z"),
        )
        for memory_type, subject, content in stored:
            store_memory(memory_type, subject, content)

        patterns = count_memories("pattern")
        anomalies = count_memories("anomaly")
        insights = count_memories("insight")
        assert patterns == 2
        assert anomalies == 1
        assert insights == 0
|