Files
Timmy-time-dashboard/tests/spark/test_memory.py
Kimi Agent 1f1bc222e4
Some checks failed
Tests / lint (push) Has been cancelled
Tests / test (push) Has been cancelled
[kimi] test: add comprehensive tests for spark modules (#659) (#695)
2026-03-21 13:32:53 +00:00

390 lines
13 KiB
Python

"""Comprehensive tests for spark.memory module.
Covers:
- SparkEvent / SparkMemory dataclasses
- _get_conn (schema creation, WAL, busy timeout, idempotent indexes)
- score_importance (all event types, boosts, edge cases)
- record_event (auto-importance, explicit importance, invalid JSON, swarm bridge)
- get_events (all filters, ordering, limit)
- count_events (total, by type)
- store_memory (with/without expiry)
- get_memories (all filters)
- count_memories (total, by type)
"""
import json
import pytest
from spark.memory import (
IMPORTANCE_HIGH,
IMPORTANCE_LOW,
IMPORTANCE_MEDIUM,
SparkEvent,
SparkMemory,
_get_conn,
count_events,
count_memories,
get_events,
get_memories,
record_event,
score_importance,
store_memory,
)
# ── Constants ─────────────────────────────────────────────────────────────
class TestConstants:
    """Sanity checks on the importance tier constants exported by spark.memory."""

    def test_importance_ordering(self):
        # The three tiers must be strictly increasing: low < medium < high.
        low, medium, high = IMPORTANCE_LOW, IMPORTANCE_MEDIUM, IMPORTANCE_HIGH
        assert low < medium
        assert medium < high
# ── Dataclasses ───────────────────────────────────────────────────────────
class TestSparkEventDataclass:
    """Construction tests for the SparkEvent dataclass."""

    def test_all_fields(self):
        # Every constructor argument must be readable back unchanged from its
        # attribute -- check all of them, not just a sample.
        ev = SparkEvent(
            id="1",
            event_type="task_posted",
            agent_id="a1",
            task_id="t1",
            description="Test",
            data="{}",
            importance=0.5,
            created_at="2026-01-01",
        )
        assert ev.id == "1"
        assert ev.event_type == "task_posted"
        assert ev.agent_id == "a1"
        assert ev.task_id == "t1"
        assert ev.description == "Test"
        assert ev.data == "{}"
        assert ev.importance == 0.5
        assert ev.created_at == "2026-01-01"

    def test_nullable_fields(self):
        # agent_id and task_id accept None.
        ev = SparkEvent(
            id="2",
            event_type="task_posted",
            agent_id=None,
            task_id=None,
            description="",
            data="{}",
            importance=0.5,
            created_at="2026-01-01",
        )
        assert ev.agent_id is None
        assert ev.task_id is None
class TestSparkMemoryDataclass:
    """Construction tests for the SparkMemory dataclass."""

    def test_all_fields(self):
        # Every constructor argument must be readable back unchanged from its
        # attribute -- check all of them, not just a sample.
        mem = SparkMemory(
            id="1",
            memory_type="pattern",
            subject="system",
            content="Test insight",
            confidence=0.8,
            source_events=5,
            created_at="2026-01-01",
            expires_at="2026-12-31",
        )
        assert mem.id == "1"
        assert mem.memory_type == "pattern"
        assert mem.subject == "system"
        assert mem.content == "Test insight"
        assert mem.confidence == 0.8
        assert mem.source_events == 5
        assert mem.created_at == "2026-01-01"
        assert mem.expires_at == "2026-12-31"

    def test_nullable_expires(self):
        # expires_at accepts None.
        mem = SparkMemory(
            id="2",
            memory_type="anomaly",
            subject="agent-1",
            content="Odd behavior",
            confidence=0.6,
            source_events=3,
            created_at="2026-01-01",
            expires_at=None,
        )
        assert mem.expires_at is None
# ── _get_conn ─────────────────────────────────────────────────────────────
class TestGetConn:
    """Schema creation and connection settings applied by _get_conn."""

    _SCHEMA_QUERY = "SELECT type, name FROM sqlite_master ORDER BY type, name"

    def test_creates_tables(self):
        with _get_conn() as conn:
            tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
        names = {r["name"] for r in tables}
        assert "spark_events" in names
        assert "spark_memories" in names

    def test_wal_mode(self):
        with _get_conn() as conn:
            mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
        assert mode == "wal"

    def test_busy_timeout(self):
        # PRAGMA busy_timeout reports milliseconds.
        with _get_conn() as conn:
            timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
        assert timeout == 5000

    def test_idempotent(self):
        # Opening a second connection re-runs schema setup. It must not raise,
        # and it must leave the set of tables/indexes exactly as the first
        # call created them.
        with _get_conn() as conn:
            first = [(row[0], row[1]) for row in conn.execute(self._SCHEMA_QUERY).fetchall()]
        with _get_conn() as conn:
            second = [(row[0], row[1]) for row in conn.execute(self._SCHEMA_QUERY).fetchall()]
        assert first == second
        assert ("table", "spark_events") in second
        assert ("table", "spark_memories") in second
# ── score_importance ──────────────────────────────────────────────────────
class TestScoreImportance:
    """Tests for score_importance: per-type base scores, bid boost, capping, rounding."""

    # (event_type, lower bound, upper bound) for the score of an event whose
    # payload is empty.
    _BASE_SCORE_CASES = [
        ("task_posted", 0.3, 0.5),
        ("bid_submitted", 0.1, 0.3),
        ("task_assigned", 0.4, 0.6),
        ("task_completed", 0.5, 0.7),
        ("task_failed", 0.9, 1.0),
        ("agent_joined", 0.4, 0.6),
        ("prediction_result", 0.6, 0.8),
    ]

    @pytest.mark.parametrize("event_type,expected_min,expected_max", _BASE_SCORE_CASES)
    def test_base_scores(self, event_type, expected_min, expected_max):
        assert expected_min <= score_importance(event_type, {}) <= expected_max

    def test_unknown_event_default(self):
        # Unrecognised event types score a neutral 0.5.
        result = score_importance("never_heard_of_this", {})
        assert result == 0.5

    def test_failure_boost(self):
        assert score_importance("task_failed", {}) == 1.0

    def test_high_bid_boost(self):
        # A larger bid_sats raises the score, but never above 1.0.
        small_bid = score_importance("bid_submitted", {"bid_sats": 10})
        big_bid = score_importance("bid_submitted", {"bid_sats": 100})
        assert big_bid > small_bid
        assert big_bid <= 1.0

    def test_high_bid_on_failure(self):
        # Boosting an already-maximal score stays capped at 1.0.
        assert score_importance("task_failed", {"bid_sats": 100}) == 1.0

    def test_score_always_rounded(self):
        boosted = score_importance("bid_submitted", {"bid_sats": 100})
        assert round(boosted, 2) == boosted
# ── record_event ──────────────────────────────────────────────────────────
class TestRecordEvent:
    """Tests for record_event: returned ids, importance scoring, payload storage."""

    def test_basic_record(self):
        eid = record_event("task_posted", "New task", task_id="t1")
        assert isinstance(eid, str)
        assert len(eid) > 0
        # The returned id must identify the row that was actually stored.
        assert [ev.id for ev in get_events(task_id="t1")] == [eid]

    def test_auto_importance(self):
        # No explicit importance: the score is derived from the event type.
        record_event("task_failed", "Failed", task_id="t-auto")
        events = get_events(task_id="t-auto")
        assert len(events) == 1
        assert events[0].importance >= 0.9

    def test_explicit_importance(self):
        # An explicit importance overrides the auto score (0.4 for task_posted).
        record_event("task_posted", "Custom", task_id="t-expl", importance=0.1)
        events = get_events(task_id="t-expl")
        assert len(events) == 1
        assert events[0].importance == 0.1

    def test_with_agent_and_data(self):
        data = json.dumps({"bid_sats": 42})
        record_event("bid_submitted", "Bid", agent_id="a1", task_id="t-data", data=data)
        events = get_events(task_id="t-data")
        assert len(events) == 1
        assert events[0].agent_id == "a1"
        parsed = json.loads(events[0].data)
        assert parsed["bid_sats"] == 42

    def test_invalid_json_data_uses_default_importance(self):
        # Unparseable data must not raise; importance falls back to the base
        # score for the event type.
        record_event("task_posted", "Bad data", task_id="t-bad", data="not-json")
        events = get_events(task_id="t-bad")
        assert len(events) == 1
        assert events[0].importance == 0.4  # base for task_posted

    def test_returns_unique_ids(self):
        id1 = record_event("task_posted", "A")
        id2 = record_event("task_posted", "B")
        assert id1 != id2
# ── get_events ────────────────────────────────────────────────────────────
class TestGetEvents:
    """Tests for get_events filters, ordering and limit."""

    def test_empty_db(self):
        # NOTE(review): assumes each test starts from a fresh, empty database
        # (presumably an autouse fixture in conftest -- not visible here).
        assert get_events() == []

    def test_filter_by_type(self):
        record_event("task_posted", "A")
        record_event("task_completed", "B")
        events = get_events(event_type="task_posted")
        assert len(events) == 1
        assert events[0].event_type == "task_posted"

    def test_filter_by_agent(self):
        record_event("task_posted", "A", agent_id="a1")
        record_event("task_posted", "B", agent_id="a2")
        events = get_events(agent_id="a1")
        assert len(events) == 1
        assert events[0].agent_id == "a1"

    def test_filter_by_task(self):
        record_event("task_posted", "A", task_id="t1")
        record_event("task_posted", "B", task_id="t2")
        events = get_events(task_id="t1")
        assert len(events) == 1
        # Pin which row matched, not just how many.
        assert events[0].task_id == "t1"

    def test_filter_by_min_importance(self):
        record_event("task_posted", "Low", importance=0.1)
        record_event("task_failed", "High", importance=0.9)
        events = get_events(min_importance=0.5)
        assert len(events) == 1
        assert events[0].importance >= 0.5
        assert events[0].description == "High"

    def test_limit(self):
        for i in range(10):
            record_event("task_posted", f"ev{i}")
        events = get_events(limit=3)
        assert len(events) == 3

    def test_order_by_created_desc(self):
        # Most recent first.
        # NOTE(review): if created_at has coarse resolution, two back-to-back
        # inserts may tie and the order would depend on tie-breaking -- confirm
        # against the implementation if this test ever flakes.
        record_event("task_posted", "first", task_id="ord1")
        record_event("task_posted", "second", task_id="ord2")
        events = get_events()
        assert [ev.task_id for ev in events] == ["ord2", "ord1"]

    def test_combined_filters(self):
        # Only "A" satisfies type, agent and importance filters simultaneously.
        record_event("task_failed", "A", agent_id="a1", task_id="t1", importance=0.9)
        record_event("task_posted", "B", agent_id="a1", task_id="t2", importance=0.4)
        record_event("task_failed", "C", agent_id="a2", task_id="t3", importance=0.9)
        events = get_events(event_type="task_failed", agent_id="a1", min_importance=0.5)
        assert len(events) == 1
        assert events[0].task_id == "t1"
# ── count_events ──────────────────────────────────────────────────────────
class TestCountEvents:
    """Tests for count_events, both unfiltered and filtered by event type."""

    def test_empty(self):
        total = count_events()
        assert total == 0

    def test_total(self):
        for event_type, description in (("task_posted", "A"), ("task_failed", "B")):
            record_event(event_type, description)
        assert count_events() == 2

    def test_by_type(self):
        seeded = (
            ("task_posted", "A"),
            ("task_posted", "B"),
            ("task_failed", "C"),
        )
        for event_type, description in seeded:
            record_event(event_type, description)
        # A type with no recorded events must count as zero.
        expected_counts = {"task_posted": 2, "task_failed": 1, "task_completed": 0}
        for event_type, expected in expected_counts.items():
            assert count_events(event_type) == expected
# ── store_memory ──────────────────────────────────────────────────────────
class TestStoreMemory:
    """Tests for store_memory: returned ids, explicit parameters and defaults."""

    def test_basic_store(self):
        mid = store_memory("pattern", "system", "Test insight")
        assert isinstance(mid, str)
        assert len(mid) > 0
        # The returned id must identify the row that was actually stored.
        assert [m.id for m in get_memories(subject="system")] == [mid]

    def test_returns_unique_ids(self):
        id1 = store_memory("pattern", "a", "X")
        id2 = store_memory("pattern", "b", "Y")
        assert id1 != id2

    def test_with_all_params(self):
        store_memory(
            "anomaly",
            "agent-1",
            "Odd pattern",
            confidence=0.9,
            source_events=10,
            expires_at="2026-12-31",
        )
        mems = get_memories(subject="agent-1")
        assert len(mems) == 1
        assert mems[0].memory_type == "anomaly"
        assert mems[0].content == "Odd pattern"
        assert mems[0].confidence == 0.9
        assert mems[0].source_events == 10
        assert mems[0].expires_at == "2026-12-31"

    def test_default_values(self):
        # Omitted optional parameters take their defaults.
        store_memory("insight", "sys", "Default test")
        mems = get_memories(subject="sys")
        assert len(mems) == 1
        assert mems[0].confidence == 0.5
        assert mems[0].source_events == 0
        assert mems[0].expires_at is None
# ── get_memories ──────────────────────────────────────────────────────────
class TestGetMemories:
    """Tests for get_memories filters and limit."""

    def test_empty(self):
        # NOTE(review): assumes each test starts from a fresh, empty database
        # (presumably an autouse fixture in conftest -- not visible here).
        assert get_memories() == []

    def test_filter_by_type(self):
        store_memory("pattern", "a", "X")
        store_memory("anomaly", "a", "Y")
        mems = get_memories(memory_type="pattern")
        assert len(mems) == 1
        assert mems[0].memory_type == "pattern"
        assert mems[0].content == "X"

    def test_filter_by_subject(self):
        store_memory("pattern", "a", "X")
        store_memory("pattern", "b", "Y")
        mems = get_memories(subject="a")
        assert len(mems) == 1
        # Pin which row matched, not just how many.
        assert mems[0].subject == "a"
        assert mems[0].content == "X"

    def test_filter_by_min_confidence(self):
        store_memory("pattern", "a", "Low", confidence=0.2)
        store_memory("pattern", "b", "High", confidence=0.9)
        mems = get_memories(min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "High"

    def test_limit(self):
        for i in range(10):
            store_memory("pattern", "a", f"M{i}")
        mems = get_memories(limit=3)
        assert len(mems) == 3

    def test_combined_filters(self):
        # Each distractor violates exactly one of the three filters.
        store_memory("pattern", "a", "Target", confidence=0.9)
        store_memory("anomaly", "a", "Wrong type", confidence=0.9)
        store_memory("pattern", "b", "Wrong subject", confidence=0.9)
        store_memory("pattern", "a", "Low conf", confidence=0.1)
        mems = get_memories(memory_type="pattern", subject="a", min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "Target"
# ── count_memories ────────────────────────────────────────────────────────
class TestCountMemories:
    """Tests for count_memories, both unfiltered and filtered by memory type."""

    def test_empty(self):
        total = count_memories()
        assert total == 0

    def test_total(self):
        for memory_type, subject, content in (("pattern", "a", "X"), ("anomaly", "b", "Y")):
            store_memory(memory_type, subject, content)
        assert count_memories() == 2

    def test_by_type(self):
        seeded = (
            ("pattern", "a", "X"),
            ("pattern", "b", "Y"),
            ("anomaly", "c", "Z"),
        )
        for memory_type, subject, content in seeded:
            store_memory(memory_type, subject, content)
        # A type with no stored memories must count as zero.
        expected_counts = {"pattern": 2, "anomaly": 1, "insight": 0}
        for memory_type, expected in expected_counts.items():
            assert count_memories(memory_type) == expected